0008-run-clang-format.patch 2.2 MB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762277227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538253925402541254225432544254525462547254825492550255125522553255425552556255725582559256025612562256325642565256625672568256925702571257225732574257525762577257825792580258125822583258425852586258725882589259025912592259325942595259625972598259926002601260226032604260526062607260826092610261126122613261426152616261726182619262026212622262326242625262626272628262926302631263226332634263526362637263826392640264126422643264426452646264726482649265026512652265326542655265626572658265926602661266226632664266526662667266826692670267126722673267426752676267726782679268026812682268326842685268626872688268926902691269226932694269526962697269826992700270127022703270427052706270727082709271027112712271327142715271627172718271927202721272227232724272527262727272827292730273127322733273427352736273727382739274027412742274327442745274627472748274927502751275227532754275527562757275827592760276127622763276427652766276727682769277027712772277327742775277627772778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763777377837793780378137823783378437853786378737883789379037913792379337943795379637973798379938003801380238033804380538063807380838093810381138123813381438153816381738183819382038213822382338243825382638273828382938303831383238333834383538363837383838393840384138423843384438453846384738483849385038513852385338543855385638573858385938603861386238633864386538663867386838693870387138723873387438753876387738783879388038813882388338843885388638873888388938903891389238933894389538963897389838993900390139023903390439053906390739083909391039113912391339143915391639173918391939203921392239233924392539263927392839293930393139323933393439353936393739383939394039413942394339443945394639473948394939503951395239533954395539563957395839593960396139623963396439653966396739683969397039713972397339743975397639773978397939803981398239833984398539863987398839893990399139923993399439953996399739983999400040014002400340044005400640074008400940104011401240134014401540164017401840194020402140224023402440254026402740284029403040314032403340344035403640374038403940404041404240434044404540464047404840494050405140524053405440554056405740584059406040614062406340644065406640674068406940704071407240734074407540764077407840794080408140824083408440854086408740884089409040914092409340944095409640974098409941004101410241034104410541064107410841094110411141124113411441154116411741184119412041214122412341244125412641274128412941304131413241334134413541364137413841394140414141424143414441454146414741484149415041514152415341544155415641574158415941604161416241634164416541664167416841694170417141724173417441754176417741784179418041814182418341844185418641874188418941904191419241934194419541964197419841994200420142024203420442054206420742084209421042114212421342144215421642174218421942204221422242234224422542264227422842294230423142324233423442354236423742384239424042414242424342444245424642474248424942504251425242534254425542564257425842594260426142624263426442654266426742684269427042714272427342744275427642774278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857485848594860486148624863486448654866486748684869487048714872487348744875487648774878487948804881488248834884488548864887488848894890489148924893489448954896489748984899490049014902490349044905490649074908490949104911491249134914491549164917491849194920492149224923492449254926492749284929493049314932493349344935493649374938493949404941494249434944494549464947494849494950495149524953495449554956495749584959496049614962496349644965496649674968496949704971497249734974497549764977497849794980498149824983498449854986498749884989499049914992499349944995499649974998499950005001500250035004500550065007500850095010501150125013501450155016501750185019502050215022502350245025502650275028502950305031503250335034503550365037503850395040504150425043504450455046504750485049505050515052505350545055505650575058505950605061506250635064506550665067506850695070507150725073507450755076507750785079508050815082508350845085508650875088508950905091509250935094509550965097509850995100510151025103510451055106510751085109511051115112511351145115511651175118511951205121512251235124512551265127512851295130513151325133513451355136513751385139514051415142514351445145514651475148514951505151515251535154515551565157515851595160516151625163516451655166516751685169517051715172517351745175517651775178517951805181518251835184518551865187518851895190519151925193519451955196519751985199520052015202520352045205520652075208520952105211521252135214521552165217521852195220522152225223522452255226522752285229523052315232523352345235523652375238523952405241524252435244524552465247524852495250525152525253525452555256525752585259526052615262526352645265526652675268526952705271527252735274527552765277527852795280528152825283528452855286528752885289529052915292529352945295529652975298529953005301530253035304530553065307530853095310531153125313531453155316531753185319532053215322532353245325532653275328532953305331533253335334533553365337533853395340534153425343534453455346534753485349535053515352535353545355535653575358535953605361536253635364536553665367536853695370537153725373537453755376537753785379538053815382538353845385538653875388538953905391539253935394539553965397539853995400540154025403540454055406540754085409541054115412541354145415541654175418541954205421542254235424542554265427542854295430543154325433543454355436543754385439544054415442544354445445544654475448544954505451545254535454545554565457545854595460546154625463546454655466546754685469547054715472547354745475547654775478547954805481548254835484548554865487548854895490549154925493549454955496549754985499550055015502550355045505550655075508550955105511551255135514551555165517551855195520552155225523552455255526552755285529553055315532553355345535553655375538553955405541554255435544554555465547554855495550555155525553555455555556555755585559556055615562556355645565556655675568556955705571557255735574557555765577557855795580558155825583558455855586558755885589559055915592559355945595559655975598559956005601560256035604560556065607560856095610561156125613561456155616561756185619562056215622562356245625562656275628562956305631563256335634563556365637563856395640564156425643564456455646564756485649565056515652565356545655565656575658565956605661566256635664566556665667566856695670567156725673567456755676567756785679568056815682568356845685568656875688568956905691569256935694569556965697569856995700570157025703570457055706570757085709571057115712571357145715571657175718571957205721572257235724572557265727572857295730573157325733573457355736573757385739574057415742574357445745574657475748574957505751575257535754575557565757575857595760576157625763576457655766576757685769577057715772577357745775577657775778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165616661676168616961706171617261736174617561766177617861796180618161826183618461856186618761886189619061916192619361946195619661976198619962006201620262036204620562066207620862096210621162126213621462156216621762186219622062216222622362246225622662276228622962306231623262336234623562366237623862396240624162426243624462456246624762486249625062516252625362546255625662576258625962606261626262636264626562666267626862696270627162726273627462756276627762786279628062816282628362846285628662876288628962906291629262936294629562966297629862996300630163026303630463056306630763086309631063116312631363146315631663176318631963206321632263236324632563266327632863296330633163326333633463356336633763386339634063416342634363446345634663476348634963506351635263536354635563566357635863596360636163626363636463656366636763686369637063716372637363746375637663776378637963806381638263836384638563866387638863896390639163926393639463956396639763986399640064016402640364046405640664076408640964106411641264136414641564166417641864196420642164226423642464256426642764286429643064316432643364346435643664376438643964406441644264436444644564466447644864496450645164526453645464556456645764586459646064616462646364646465646664676468646964706471647264736474647564766477647864796480648164826483648464856486648764886489649064916492649364946495649664976498649965006501650265036504650565066507650865096510651165126513651465156516651765186519652065216522652365246525652665276528652965306531653265336534653565366537653865396540654165426543654465456546654765486549655065516552655365546555655665576558655965606561656265636564656565666567656865696570657165726573657465756576657765786579658065816582658365846585658665876588658965906591659265936594659565966597659865996600660166026603660466056606660766086609661066116612661366146615661666176618661966206621662266236624662566266627662866296630663166326633663466356636663766386639664066416642664366446645664666476648664966506651665266536654665566566657665866596660666166626663666466656666666766686669667066716672667366746675667666776678667966806681668266836684668566866687668866896690669166926693669466956696669766986699670067016702670367046705670667076708670967106711671267136714671567166717671867196720672167226723672467256726672767286729673067316732673367346735673667376738673967406741674267436744674567466747674867496750675167526753675467556756675767586759676067616762676367646765676667676768676967706771677267736774677567766777677867796780678167826783678467856786678767886789679067916792679367946795679667976798679968006801680268036804680568066807680868096810681168126813681468156816681768186819682068216822682368246825682668276828682968306831683268336834683568366837683868396840684168426843684468456846684768486849685068516852685368546855685668576858685968606861686268636864686568666867686868696870687168726873687468756876687768786879688068816882688368846885688668876888688968906891689268936894689568966897689868996900690169026903690469056906690769086909691069116912691369146915691669176918691969206921692269236924692569266927692869296930693169326933693469356936693769386939694069416942694369446945694669476948694969506951695269536954695569566957695869596960696169626963696469656966696769686969697069716972697369746975697669776978697969806981698269836984698569866987698869896990699169926993699469956996699769986999700070017002700370047005700670077008700970107011701270137014701570167017701870197020702170227023702470257026702770287029703070317032703370347035703670377038703970407041704270437044704570467047704870497050705170527053705470557056705770587059706070617062706370647065706670677068706970707071707270737074707570767077707870797080708170827083708470857086708770887089709070917092709370947095709670977098709971007101710271037104710571067107710871097110711171127113711471157116711771187119712071217122712371247125712671277128712971307131713271337134713571367137713871397140714171427143714471457146714771487149715071517152715371547155715671577158715971607161716271637164716571667167716871697170717171727173717471757176717771787179718071817182718371847185718671877188718971907191719271937194719571967197719871997200720172027203720472057206720772087209721072117212721372147215721672177218721972207221722272237224722572267227722872297230723172327233723472357236723772387239724072417242724372447245724672477248724972507251725272537254725572567257725872597260726172627263726472657266726772687269727072717272727372747275727672777278727972807281728272837284728572867287728872897290729172927293729472957296729772987299730073017302730373047305730673077308730973107311731273137314731573167317731873197320732173227323732473257326732773287329733073317332733373347335733673377338733973407341734273437344734573467347734873497350735173527353735473557356735773587359736073617362736373647365736673677368736973707371737273737374737573767377737873797380738173827383738473857386738773887389739073917392739373947395739673977398739974007401740274037404740574067407740874097410741174127413741474157416741774187419742074217422742374247425742674277428742974307431743274337434743574367437743874397440744174427443744474457446744774487449745074517452745374547455745674577458745974607461746274637464746574667467746874697470747174727473747474757476747774787479748074817482748374847485748674877488748974907491749274937494749574967497749874997500750175027503750475057506750775087509751075117512751375147515751675177518751975207521752275237524752575267527752875297530753175327533753475357536753775387539754075417542754375447545754675477548754975507551755275537554755575567557755875597560756175627563756475657566756775687569757075717572757375747575757675777578757975807581758275837584758575867587758875897590759175927593759475957596759775987599760076017602760376047605760676077608760976107611761276137614761576167617761876197620762176227623762476257626762776287629763076317632763376347635763676377638763976407641764276437644764576467647764876497650765176527653765476557656765776587659766076617662766376647665766676677668766976707671767276737674767576767677767876797680768176827683768476857686768776887689769076917692769376947695769676977698769977007701770277037704770577067707770877097710771177127713771477157716771777187719772077217722772377247725772677277728772977307731773277337734773577367737773877397740774177427743774477457746774777487749775077517752775377547755775677577758775977607761776277637764776577667767776877697770777177727773777477757776777777787779778077817782778377847785778677877788778977907791779277937794779577967797779877997800780178027803780478057806780778087809781078117812781378147815781678177818781978207821782278237824782578267827782878297830783178327833783478357836783778387839784078417842784378447845784678477848784978507851785278537854785578567857785878597860786178627863786478657866786778687869787078717872787378747875787678777878787978807881788278837884788578867887788878897890789178927893789478957896789778987899790079017902790379047905790679077908790979107911791279137914791579167917791879197920792179227923792479257926792779287929793079317932793379347935793679377938793979407941794279437944794579467947794879497950795179527953795479557956795779587959796079617962796379647965796679677968796979707971797279737974797579767977797879797980798179827983798479857986798779887989799079917992799379947995799679977998799980008001800280038004800580068007800880098010801180128013801480158016801780188019802080218022802380248025802680278028802980308031803280338034803580368037803880398040804180428043804480458046804780488049805080518052805380548055805680578058805980608061806280638064806580668067806880698070807180728073807480758076807780788079808080818082808380848085808680878088808980908091809280938094809580968097809880998100810181028103810481058106810781088109811081118112811381148115811681178118811981208121812281238124812581268127812881298130813181328133813481358136813781388139814081418142814381448145814681478148814981508151815281538154815581568157815881598160816181628163816481658166816781688169817081718172817381748175817681778178817981808181818281838184818581868187818881898190819181928193819481958196819781988199820082018202820382048205820682078208820982108211821282138214821582168217821882198220822182228223822482258226822782288229823082318232823382348235823682378238823982408241824282438244824582468247824882498250825182528253825482558256825782588259826082618262826382648265826682678268826982708271827282738274827582768277827882798280828182828283828482858286828782888289829082918292829382948295829682978298829983008301830283038304830583068307830883098310831183128313831483158316831783188319832083218322832383248325832683278328832983308331833283338334833583368337833883398340834183428343834483458346834783488349835083518352835383548355835683578358835983608361836283638364836583668367836883698370837183728373837483758376837783788379838083818382838383848385838683878388838983908391839283938394839583968397839883998400840184028403840484058406840784088409841084118412841384148415841684178418841984208421842284238424842584268427842884298430843184328433843484358436843784388439844084418442844384448445844684478448844984508451845284538454845584568457845884598460846184628463846484658466846784688469847084718472847384748475847684778478847984808481848284838484848584868487848884898490849184928493849484958496849784988499850085018502850385048505850685078508850985108511851285138514851585168517851885198520852185228523852485258526852785288529853085318532853385348535853685378538853985408541854285438544854585468547854885498550855185528553855485558556855785588559856085618562856385648565856685678568856985708571857285738574857585768577857885798580858185828583858485858586858785888589859085918592859385948595859685978598859986008601860286038604860586068607860886098610861186128613861486158616861786188619862086218622862386248625862686278628862986308631863286338634863586368637863886398640864186428643864486458646864786488649865086518652865386548655865686578658865986608661866286638664866586668667866886698670867186728673867486758676867786788679868086818682868386848685868686878688868986908691869286938694869586968697869886998700870187028703870487058706870787088709871087118712871387148715871687178718871987208721872287238724872587268727872887298730873187328733873487358736873787388739874087418742874387448745874687478748874987508751875287538754875587568757875887598760876187628763876487658766876787688769877087718772877387748775877687778778877987808781878287838784878587868787878887898790879187928793879487958796879787988799880088018802880388048805880688078808880988108811881288138814881588168817881888198820882188228823882488258826882788288829883088318832883388348835883688378838883988408841884288438844884588468847884888498850885188528853885488558856885788588859886088618862886388648865886688678868886988708871887288738874887588768877887888798880888188828883888488858886888788888889889088918892889388948895889688978898889989008901890289038904890589068907890889098910891189128913891489158916891789188919892089218922892389248925892689278928892989308931893289338934893589368937893889398940894189428943894489458946894789488949895089518952895389548955895689578958895989608961896289638964896589668967896889698970897189728973897489758976897789788979898089818982898389848985898689878988898989908991899289938994899589968997899889999000900190029003900490059006900790089009901090119012901390149015901690179018901990209021902290239024902590269027902890299030903190329033903490359036903790389039904090419042904390449045904690479048904990509051905290539054905590569057905890599060906190629063906490659066906790689069907090719072907390749075907690779078907990809081908290839084908590869087908890899090909190929093909490959096909790989099910091019102910391049105910691079108910991109111911291139114911591169117911891199120912191229123912491259126912791289129913091319132913391349135913691379138913991409141914291439144914591469147914891499150915191529153915491559156915791589159916091619162916391649165916691679168916991709171917291739174917591769177917891799180918191829183918491859186918791889189919091919192919391949195919691979198919992009201920292039204920592069207920892099210921192129213921492159216921792189219922092219222922392249225922692279228922992309231923292339234923592369237923892399240924192429243924492459246924792489249925092519252925392549255925692579258925992609261926292639264926592669267926892699270927192729273927492759276927792789279928092819282928392849285928692879288928992909291929292939294929592969297929892999300930193029303930493059306930793089309931093119312931393149315931693179318931993209321932293239324932593269327932893299330933193329333933493359336933793389339934093419342934393449345934693479348934993509351935293539354935593569357935893599360936193629363936493659366936793689369937093719372937393749375937693779378937993809381938293839384938593869387938893899390939193929393939493959396939793989399940094019402940394049405940694079408940994109411941294139414941594169417941894199420942194229423942494259426942794289429943094319432943394349435943694379438943994409441944294439444944594469447944894499450945194529453945494559456945794589459946094619462946394649465946694679468946994709471947294739474947594769477947894799480948194829483948494859486948794889489949094919492949394949495949694979498949995009501950295039504950595069507950895099510951195129513951495159516951795189519952095219522952395249525952695279528952995309531953295339534953595369537953895399540954195429543954495459546954795489549955095519552955395549555955695579558955995609561956295639564956595669567956895699570957195729573957495759576957795789579958095819582958395849585958695879588958995909591959295939594959595969597959895999600960196029603960496059606960796089609961096119612961396149615961696179618961996209621962296239624962596269627962896299630963196329633963496359636963796389639964096419642964396449645964696479648964996509651965296539654965596569657965896599660966196629663966496659666966796689669967096719672967396749675967696779678967996809681968296839684968596869687968896899690969196929693969496959696969796989699970097019702970397049705970697079708970997109711971297139714971597169717971897199720972197229723972497259726972797289729973097319732973397349735973697379738973997409741974297439744974597469747974897499750975197529753975497559756975797589759976097619762976397649765976697679768976997709771977297739774977597769777977897799780978197829783978497859786978797889789979097919792979397949795979697979798979998009801980298039804980598069807980898099810981198129813981498159816981798189819982098219822982398249825982698279828982998309831983298339834983598369837983898399840984198429843984498459846984798489849985098519852985398549855985698579858985998609861986298639864986598669867986898699870987198729873987498759876987798789879988098819882988398849885988698879888988998909891989298939894989598969897989898999900990199029903990499059906990799089909991099119912991399149915991699179918991999209921992299239924992599269927992899299930993199329933993499359936993799389939994099419942994399449945994699479948994999509951995299539954995599569957995899599960996199629963996499659966996799689969997099719972997399749975997699779978997999809981998299839984998599869987998899899990999199929993999499959996999799989999100001000110002100031000410005100061000710008100091001010011100121001310014100151001610017100181001910020100211002210023100241002510026100271002810029100301003110032100331003410035100361003710038100391004010041100421004310044100451004610047100481004910050100511005210053100541005510056100571005810059100601006110062100631006410065100661006710068100691007010071100721007310074100751007610077100781007910080100811008210083100841008510086100871008810089100901009110092100931009410095100961009710098100991010010101101021010310104101051010610107101081010910110101111011210113101141011510116101171011810119101201012110122101231012410125101261012710128101291013010131101321013310134101351013610137101381013910140101411014210143101441014510146101471014810149101501015110152101531015410155101561015710158101591016010161101621016310164101651016610167101681016910170101711017210173101741017510176101771017810179101801018110182101831018410185101861018710188101891019010191101921019310194101951019610197101981019910200102011020210203102041020510206102071020810209102101021110212102131021410215102161021710218102191022010221102221022310224102251022610227102281022910230102311023210233102341023510236102371023810239102401024110242102431024410245102461024710248102491025010251102521025310254102551025610257102581025910260102611026210263102641026510266102671026810269102701027110272102731027410275102761027710278102791028010281102821028310284102851028610287102881028910290102911029210293102941029510296102971029810299103001030110302103031030410305103061030710308103091031010311103121031310314103151031610317103181031910320103211032210323103241032510326103271032810329103301033110332103331033410335103361033710338103391034010341103421034310344103451034610347103481034910350103511035210353103541035510356103571035810359103601036110362103631036410365103661036710368103691037010371103721037310374103751037610377103781037910380103811038210383103841038510386103871038810389103901039110392103931039410395103961039710398103991040010401104021040310404104051040610407104081040910410104111041210413104141041510416104171041810419104201042110422104231042410425104261042710428104291043010431104321043310434104351043610437104381043910440104411044210443104441044510446104471044810449104501045110452104531045410455104561045710458104591046010461104621046310464104651046610467104681046910470104711047210473104741047510476104771047810479104801048110482104831048410485104861048710488104891049010491104921049310494104951049610497104981049910500105011050210503105041050510506105071050810509105101051110512105131051410515105161051710518105191052010521105221052310524105251052610527105281052910530105311053210533105341053510536105371053810539105401054110542105431054410545105461054710548105491055010551105521055310554105551055610557105581055910560105611056210563105641056510566105671056810569105701057110572105731057410575105761057710578105791058010581105821058310584105851058610587105881058910590105911059210593105941059510596105971059810599106001060110602106031060410605106061060710608106091061010611106121061310614106151061610617106181061910620106211062210623106241062510626106271062810629106301063110632106331063410635106361063710638106391064010641106421064310644106451064610647106481064910650106511065210653106541065510656106571065810659106601066110662106631066410665106661066710668106691067010671106721067310674106751067610677106781067910680106811068210683106841068510686106871068810689106901069110692106931069410695106961069710698106991070010701107021070310704107051070610707107081070910710107111071210713107141071510716107171071810719107201072110722107231072410725107261072710728107291073010731107321073310734107351073610737107381073910740107411074210743107441074510746107471074810749107501075110752107531075410755107561075710758107591076010761107621076310764107651076610767107681076910770107711077210773107741077510776107771077810779107801078110782107831078410785107861078710788107891079010791107921079310794107951079610797107981079910800108011080210803108041080510806108071080810809108101081110812108131081410815108161081710818108191082010821108221082310824108251082610827108281082910830108311083210833108341083510836108371083810839108401084110842108431084410845108461084710848108491085010851108521085310854108551085610857108581085910860108611086210863108641086510866108671086810869108701087110872108731087410875108761087710878108791088010881108821088310884108851088610887108881088910890108911089210893108941089510896108971089810899109001090110902109031090410905109061090710908109091091010911109121091310914109151091610917109181091910920109211092210923109241092510926109271092810929109301093110932109331093410935109361093710938109391094010941109421094310944109451094610947109481094910950109511095210953109541095510956109571095810959109601096110962109631096410965109661096710968109691097010971109721097310974109751097610977109781097910980109811098210983109841098510986109871098810989109901099110992109931099410995109961099710998109991100011001110021100311004110051100611007110081100911010110111101211013110141101511016110171101811019110201102111022110231102411025110261102711028110291103011031110321103311034110351103611037110381103911040110411104211043110441104511046110471104811049110501105111052110531105411055110561105711058110591106011061110621106311064110651106611067110681106911070110711107211073110741107511076110771107811079110801108111082110831108411085110861108711088110891109011091110921109311094110951109611097110981109911100111011110211103111041110511106111071110811109111101111111112111131111411115111161111711118111191112011121111221112311124111251112611127111281112911130111311113211133111341113511136111371113811139111401114111142111431114411145111461114711148111491115011151111521115311154111551115611157111581115911160111611116211163111641116511166111671116811169111701117111172111731117411175111761117711178111791118011181111821118311184111851118611187111881118911190111911119211193111941119511196111971119811199112001120111202112031120411205112061120711208112091121011211112121121311214112151121611217112181121911220112211122211223112241122511226112271122811229112301123111232112331123411235112361123711238112391124011241112421124311244112451124611247112481124911250112511125211253112541125511256112571125811259112601126111262112631126411265112661126711268112691127011271112721127311274112751127611277112781127911280112811128211283112841128511286112871128811289112901129111292112931129411295112961129711298112991130011301113021130311304113051130611307113081130911310113111131211313113141131511316113171131811319113201132111322113231132411325113261132711328113291133011331113321133311334113351133611337113381133911340113411134211343113441134511346113471134811349113501135111352113531135411355113561135711358113591136011361113621136311364113651136611367113681136911370113711137211373113741137511376113771137811379113801138111382113831138411385113861138711388113891139011391113921139311394113951139611397113981139911400114011140211403114041140511406114071140811409114101141111412114131141411415114161141711418114191142011421114221142311424114251142611427114281142911430114311143211433114341143511436114371143811439114401144111442114431144411445114461144711448114491145011451114521145311454114551145611457114581145911460114611146211463114641146511466114671146811469114701147111472114731147411475114761147711478114791148011481114821148311484114851148611487114881148911490114911149211493114941149511496114971149811499115001150111502115031150411505115061150711508115091151011511115121151311514115151151611517115181151911520115211152211523115241152511526115271152811529115301153111532115331153411535115361153711538115391154011541115421154311544115451154611547115481154911550115511155211553115541155511556115571155811559115601156111562115631156411565115661156711568115691157011571115721157311574115751157611577115781157911580115811158211583115841158511586115871158811589115901159111592115931159411595115961159711598115991160011601116021160311604116051160611607116081160911610116111161211613116141161511616116171161811619116201162111622116231162411625116261162711628116291163011631116321163311634116351163611637116381163911640116411164211643116441164511646116471164811649116501165111652116531165411655116561165711658116591166011661116621166311664116651166611667116681166911670116711167211673116741167511676116771167811679116801168111682116831168411685116861168711688116891169011691116921169311694116951169611697116981169911700117011170211703117041170511706117071170811709117101171111712117131171411715117161171711718117191172011721117221172311724117251172611727117281172911730117311173211733117341173511736117371173811739117401174111742117431174411745117461174711748117491175011751117521175311754117551175611757117581175911760117611176211763117641176511766117671176811769117701177111772117731177411775117761177711778117791178011781117821178311784117851178611787117881178911790117911179211793117941179511796117971179811799118001180111802118031180411805118061180711808118091181011811118121181311814118151181611817118181181911820118211182211823118241182511826118271182811829118301183111832118331183411835118361183711838118391184011841118421184311844118451184611847118481184911850118511185211853118541185511856118571185811859118601186111862118631186411865118661186711868118691187011871118721187311874118751187611877118781187911880118811188211883118841188511886118871188811889118901189111892118931189411895118961189711898118991190011901119021190311904119051190611907119081190911910119111191211913119141191511916119171191811919119201192111922119231192411925119261192711928119291193011931119321193311934119351193611937119381193911940119411194211943119441194511946119471194811949119501195111952119531195411955119561195711958119591196011961119621196311964119651196611967119681196911970119711197211973119741197511976119771197811979119801198111982119831198411985119861198711988119891199011991119921199311994119951199611997119981199912000120011200212003120041200512006120071200812009120101201112012120131201412015120161201712018120191202012021120221202312024120251202612027120281202912030120311203212033120341203512036120371203812039120401204112042120431204412045120461204712048120491205012051120521205312054120551205612057120581205912060120611206212063120641206512066120671206812069120701207112072120731207412075120761207712078120791208012081120821208312084120851208612087120881208912090120911209212093120941209512096120971209812099121001210112102121031210412105121061210712108121091211012111121121211312114121151211612117121181211912120121211212212123121241212512126121271212812129121301213112132121331213412135121361213712138121391214012141121421214312144121451214612147121481214912150121511215212153121541215512156121571215812159121601216112162121631216412165121661216712168121691217012171121721217312174121751217612177121781217912180121811218212183121841218512186121871218812189121901219112192121931219412195121961219712198121991220012201122021220312204122051220612207122081220912210122111221212213122141221512216122171221812219122201222112222122231222412225122261222712228122291223012231122321223312234122351223612237122381223912240122411224212243122441224512246122471224812249122501225112252122531225412255122561225712258122591226012261122621226312264122651226612267122681226912270122711227212273122741227512276122771227812279122801228112282122831228412285122861228712288122891229012291122921229312294122951229612297122981229912300123011230212303123041230512306123071230812309123101231112312123131231412315123161231712318123191232012321123221232312324123251232612327123281232912330123311233212333123341233512336123371233812339123401234112342123431234412345123461234712348123491235012351123521235312354123551235612357123581235912360123611236212363123641236512366123671236812369123701237112372123731237412375123761237712378123791238012381123821238312384123851238612387123881238912390123911239212393123941239512396123971239812399124001240112402124031240412405124061240712408124091241012411124121241312414124151241612417124181241912420124211242212423124241242512426124271242812429124301243112432124331243412435124361243712438124391244012441124421244312444124451244612447124481244912450124511245212453124541245512456124571245812459124601246112462124631246412465124661246712468124691247012471124721247312474124751247612477124781247912480124811248212483124841248512486124871248812489124901249112492124931249412495124961249712498124991250012501125021250312504125051250612507125081250912510125111251212513125141251512516125171251812519125201252112522125231252412525125261252712528125291253012531125321253312534125351253612537125381253912540125411254212543125441254512546125471254812549125501255112552125531255412555125561255712558125591256012561125621256312564125651256612567125681256912570125711257212573125741257512576125771257812579125801258112582125831258412585125861258712588125891259012591125921259312594125951259612597125981259912600126011260212603126041260512606126071260812609126101261112612126131261412615126161261712618126191262012621126221262312624126251262612627126281262912630126311263212633126341263512636126371263812639126401264112642126431264412645126461264712648126491265012651126521265312654126551265612657126581265912660126611266212663126641266512666126671266812669126701267112672126731267412675126761267712678126791268012681126821268312684126851268612687126881268912690126911269212693126941269512696126971269812699127001270112702127031270412705127061270712708127091271012711127121271312714127151271612717127181271912720127211272212723127241272512726127271272812729127301273112732127331273412735127361273712738127391274012741127421274312744127451274612747127481274912750127511275212753127541275512756127571275812759127601276112762127631276412765127661276712768127691277012771127721277312774127751277612777127781277912780127811278212783127841278512786127871278812789127901279112792127931279412795127961279712798127991280012801128021280312804128051280612807128081280912810128111281212813128141281512816128171281812819128201282112822128231282412825128261282712828128291283012831128321283312834128351283612837128381283912840128411284212843128441284512846128471284812849128501285112852128531285412855128561285712858128591286012861128621286312864128651286612867128681286912870128711287212873128741287512876128771287812879128801288112882128831288412885128861288712888128891289012891128921289312894128951289612897128981289912900129011290212903129041290512906129071290812909129101291112912129131291412915129161291712918129191292012921129221292312924129251292612927129281292912930129311293212933129341293512936129371293812939129401294112942129431294412945129461294712948129491295012951129521295312954129551295612957129581295912960129611296212963129641296512966129671296812969129701297112972129731297412975129761297712978129791298012981129821298312984129851298612987129881298912990129911299212993129941299512996129971299812999130001300113002130031300413005130061300713008130091301013011130121301313014130151301613017130181301913020130211302213023130241302513026130271302813029130301303113032130331303413035130361303713038130391304013041130421304313044130451304613047130481304913050130511305213053130541305513056130571305813059130601306113062130631306413065130661306713068130691307013071130721307313074130751307613077130781307913080130811308213083130841308513086130871308813089130901309113092130931309413095130961309713098130991310013101131021310313104131051310613107131081310913110131111311213113131141311513116131171311813119131201312113122131231312413125131261312713128131291313013131131321313313134131351313613137131381313913140131411314213143131441314513146131471314813149131501315113152131531315413155131561315713158131591316013161131621316313164131651316613167131681316913170131711317213173131741317513176131771317813179131801318113182131831318413185131861318713188131891319013191131921319313194131951319613197131981319913200132011320213203132041320513206132071320813209132101321113212132131321413215132161321713218132191322013221132221322313224132251322613227132281322913230132311323213233132341323513236132371323813239132401324113242132431324413245132461324713248132491325013251132521325313254132551325613257132581325913260132611326213263132641326513266132671326813269132701327113272132731327413275132761327713278132791328013281132821328313284132851328613287132881328913290132911329213293132941329513296132971329813299133001330113302133031330413305133061330713308133091331013311133121331313314133151331613317133181331913320133211332213323133241332513326133271332813329133301333113332133331333413335133361333713338133391334013341133421334313344133451334613347133481334913350133511335213353133541335513356133571335813359133601336113362133631336413365133661336713368133691337013371133721337313374133751337613377133781337913380133811338213383133841338513386133871338813389133901339113392133931339413395133961339713398133991340013401134021340313404134051340613407134081340913410134111341213413134141341513416134171341813419134201342113422134231342413425134261342713428134291343013431134321343313434134351343613437134381343913440134411344213443134441344513446134471344813449134501345113452134531345413455134561345713458134591346013461134621346313464134651346613467134681346913470134711347213473134741347513476134771347813479134801348113482134831348413485134861348713488134891349013491134921349313494134951349613497134981349913500135011350213503135041350513506135071350813509135101351113512135131351413515135161351713518135191352013521135221352313524135251352613527135281352913530135311353213533135341353513536135371353813539135401354113542135431354413545135461354713548135491355013551135521355313554135551355613557135581355913560135611356213563135641356513566135671356813569135701357113572135731357413575135761357713578135791358013581135821358313584135851358613587135881358913590135911359213593135941359513596135971359813599136001360113602136031360413605136061360713608136091361013611136121361313614136151361613617136181361913620136211362213623136241362513626136271362813629136301363113632136331363413635136361363713638136391364013641136421364313644136451364613647136481364913650136511365213653136541365513656136571365813659136601366113662136631366413665136661366713668136691367013671136721367313674136751367613677136781367913680136811368213683136841368513686136871368813689136901369113692136931369413695136961369713698136991370013701137021370313704137051370613707137081370913710137111371213713137141371513716137171371813719137201372113722137231372413725137261372713728137291373013731137321373313734137351373613737137381373913740137411374213743137441374513746137471374813749137501375113752137531375413755137561375713758137591376013761137621376313764137651376613767137681376913770137711377213773137741377513776137771377813779137801378113782137831378413785137861378713788137891379013791137921379313794137951379613797137981379913800138011380213803138041380513806138071380813809138101381113812138131381413815138161381713818138191382013821138221382313824138251382613827138281382913830138311383213833138341383513836138371383813839138401384113842138431384413845138461384713848138491385013851138521385313854138551385613857138581385913860138611386213863138641386513866138671386813869138701387113872138731387413875138761387713878138791388013881138821388313884138851388613887138881388913890138911389213893138941389513896138971389813899139001390113902139031390413905139061390713908139091391013911139121391313914139151391613917139181391913920139211392213923139241392513926139271392813929139301393113932139331393413935139361393713938139391394013941139421394313944139451394613947139481394913950139511395213953139541395513956139571395813959139601396113962139631396413965139661396713968139691397013971139721397313974139751397613977139781397913980139811398213983139841398513986139871398813989139901399113992139931399413995139961399713998139991400014001140021400314004140051400614007140081400914010140111401214013140141401514016140171401814019140201402114022140231402414025140261402714028140291403014031140321403314034140351403614037140381403914040140411404214043140441404514046140471404814049140501405114052140531405414055140561405714058140591406014061140621406314064140651406614067140681406914070140711407214073140741407514076140771407814079140801408114082140831408414085140861408714088140891409014091140921409314094140951409614097140981409914100141011410214103141041410514106141071410814109141101411114112141131411414115141161411714118141191412014121141221412314124141251412614127141281412914130141311413214133141341413514136141371413814139141401414114142141431414414145141461414714148141491415014151141521415314154141551415614157141581415914160141611416214163141641416514166141671416814169141701417114172141731417414175141761417714178141791418014181141821418314184141851418614187141881418914190141911419214193141941419514196141971419814199142001420114202142031420414205142061420714208142091421014211142121421314214142151421614217142181421914220142211422214223142241422514226142271422814229142301423114232142331423414235142361423714238142391424014241142421424314244142451424614247142481424914250142511425214253142541425514256142571425814259142601426114262142631426414265142661426714268142691427014271142721427314274142751427614277142781427914280142811428214283142841428514286142871428814289142901429114292142931429414295142961429714298142991430014301143021430314304143051430614307143081430914310143111431214313143141431514316143171431814319143201432114322143231432414325143261432714328143291433014331143321433314334143351433614337143381433914340143411434214343143441434514346143471434814349143501435114352143531435414355143561435714358143591436014361143621436314364143651436614367143681436914370143711437214373143741437514376143771437814379143801438114382143831438414385143861438714388143891439014391143921439314394143951439614397143981439914400144011440214403144041440514406144071440814409144101441114412144131441414415144161441714418144191442014421144221442314424144251442614427144281442914430144311443214433144341443514436144371443814439144401444114442144431444414445144461444714448144491445014451144521445314454144551445614457144581445914460144611446214463144641446514466144671446814469144701447114472144731447414475144761447714478144791448014481144821448314484144851448614487144881448914490144911449214493144941449514496144971449814499145001450114502145031450414505145061450714508145091451014511145121451314514145151451614517145181451914520145211452214523145241452514526145271452814529145301453114532145331453414535145361453714538145391454014541145421454314544145451454614547145481454914550145511455214553145541455514556145571455814559145601456114562145631456414565145661456714568145691457014571145721457314574145751457614577145781457914580145811458214583145841458514586145871458814589145901459114592145931459414595145961459714598145991460014601146021460314604146051460614607146081460914610146111461214613146141461514616146171461814619146201462114622146231462414625146261462714628146291463014631146321463314634146351463614637146381463914640146411464214643146441464514646146471464814649146501465114652146531465414655146561465714658146591466014661146621466314664146651466614667146681466914670146711467214673146741467514676146771467814679146801468114682146831468414685146861468714688146891469014691146921469314694146951469614697146981469914700147011470214703147041470514706147071470814709147101471114712147131471414715147161471714718147191472014721147221472314724147251472614727147281472914730147311473214733147341473514736147371473814739147401474114742147431474414745147461474714748147491475014751147521475314754147551475614757147581475914760147611476214763147641476514766147671476814769147701477114772147731477414775147761477714778147791478014781147821478314784147851478614787147881478914790147911479214793147941479514796147971479814799148001480114802148031480414805148061480714808148091481014811148121481314814148151481614817148181481914820148211482214823148241482514826148271482814829148301483114832148331483414835148361483714838148391484014841148421484314844148451484614847148481484914850148511485214853148541485514856148571485814859148601486114862148631486414865148661486714868148691487014871148721487314874148751487614877148781487914880148811488214883148841488514886148871488814889148901489114892148931489414895148961489714898148991490014901149021490314904149051490614907149081490914910149111491214913149141491514916149171491814919149201492114922149231492414925149261492714928149291493014931149321493314934149351493614937149381493914940149411494214943149441494514946149471494814949149501495114952149531495414955149561495714958149591496014961149621496314964149651496614967149681496914970149711497214973149741497514976149771497814979149801498114982149831498414985149861498714988149891499014991149921499314994149951499614997149981499915000150011500215003150041500515006150071500815009150101501115012150131501415015150161501715018150191502015021150221502315024150251502615027150281502915030150311503215033150341503515036150371503815039150401504115042150431504415045150461504715048150491505015051150521505315054150551505615057150581505915060150611506215063150641506515066150671506815069150701507115072150731507415075150761507715078150791508015081150821508315084150851508615087150881508915090150911509215093150941509515096150971509815099151001510115102151031510415105151061510715108151091511015111151121511315114151151511615117151181511915120151211512215123151241512515126151271512815129151301513115132151331513415135151361513715138151391514015141151421514315144151451514615147151481514915150151511515215153151541515515156151571515815159151601516115162151631516415165151661516715168151691517015171151721517315174151751517615177151781517915180151811518215183151841518515186151871518815189151901519115192151931519415195151961519715198151991520015201152021520315204152051520615207152081520915210152111521215213152141521515216152171521815219152201522115222152231522415225152261522715228152291523015231152321523315234152351523615237152381523915240152411524215243152441524515246152471524815249152501525115252152531525415255152561525715258152591526015261152621526315264152651526615267152681526915270152711527215273152741527515276152771527815279152801528115282152831528415285152861528715288152891529015291152921529315294152951529615297152981529915300153011530215303153041530515306153071530815309153101531115312153131531415315153161531715318153191532015321153221532315324153251532615327153281532915330153311533215333153341533515336153371533815339153401534115342153431534415345153461534715348153491535015351153521535315354153551535615357153581535915360153611536215363153641536515366153671536815369153701537115372153731537415375153761537715378153791538015381153821538315384153851538615387153881538915390153911539215393153941539515396153971539815399154001540115402154031540415405154061540715408154091541015411154121541315414154151541615417154181541915420154211542215423154241542515426154271542815429154301543115432154331543415435154361543715438154391544015441154421544315444154451544615447154481544915450154511545215453154541545515456154571545815459154601546115462154631546415465154661546715468154691547015471154721547315474154751547615477154781547915480154811548215483154841548515486154871548815489154901549115492154931549415495154961549715498154991550015501155021550315504155051550615507155081550915510155111551215513155141551515516155171551815519155201552115522155231552415525155261552715528155291553015531155321553315534155351553615537155381553915540155411554215543155441554515546155471554815549155501555115552155531555415555155561555715558155591556015561155621556315564155651556615567155681556915570155711557215573155741557515576155771557815579155801558115582155831558415585155861558715588155891559015591155921559315594155951559615597155981559915600156011560215603156041560515606156071560815609156101561115612156131561415615156161561715618156191562015621156221562315624156251562615627156281562915630156311563215633156341563515636156371563815639156401564115642156431564415645156461564715648156491565015651156521565315654156551565615657156581565915660156611566215663156641566515666156671566815669156701567115672156731567415675156761567715678156791568015681156821568315684156851568615687156881568915690156911569215693156941569515696156971569815699157001570115702157031570415705157061570715708157091571015711157121571315714157151571615717157181571915720157211572215723157241572515726157271572815729157301573115732157331573415735157361573715738157391574015741157421574315744157451574615747157481574915750157511575215753157541575515756157571575815759157601576115762157631576415765157661576715768157691577015771157721577315774157751577615777157781577915780157811578215783157841578515786157871578815789157901579115792157931579415795157961579715798157991580015801158021580315804158051580615807158081580915810158111581215813158141581515816158171581815819158201582115822158231582415825158261582715828158291583015831158321583315834158351583615837158381583915840158411584215843158441584515846158471584815849158501585115852158531585415855158561585715858158591586015861158621586315864158651586615867158681586915870158711587215873158741587515876158771587815879158801588115882158831588415885158861588715888158891589015891158921589315894158951589615897158981589915900159011590215903159041590515906159071590815909159101591115912159131591415915159161591715918159191592015921159221592315924159251592615927159281592915930159311593215933159341593515936159371593815939159401594115942159431594415945159461594715948159491595015951159521595315954159551595615957159581595915960159611596215963159641596515966159671596815969159701597115972159731597415975159761597715978159791598015981159821598315984159851598615987159881598915990159911599215993159941599515996159971599815999160001600116002160031600416005160061600716008160091601016011160121601316014160151601616017160181601916020160211602216023160241602516026160271602816029160301603116032160331603416035160361603716038160391604016041160421604316044160451604616047160481604916050160511605216053160541605516056160571605816059160601606116062160631606416065160661606716068160691607016071160721607316074160751607616077160781607916080160811608216083160841608516086160871608816089160901609116092160931609416095160961609716098160991610016101161021610316104161051610616107161081610916110161111611216113161141611516116161171611816119161201612116122161231612416125161261612716128161291613016131161321613316134161351613616137161381613916140161411614216143161441614516146161471614816149161501615116152161531615416155161561615716158161591616016161161621616316164161651616616167161681616916170161711617216173161741617516176161771617816179161801618116182161831618416185161861618716188161891619016191161921619316194161951619616197161981619916200162011620216203162041620516206162071620816209162101621116212162131621416215162161621716218162191622016221162221622316224162251622616227162281622916230162311623216233162341623516236162371623816239162401624116242162431624416245162461624716248162491625016251162521625316254162551625616257162581625916260162611626216263162641626516266162671626816269162701627116272162731627416275162761627716278162791628016281162821628316284162851628616287162881628916290162911629216293162941629516296162971629816299163001630116302163031630416305163061630716308163091631016311163121631316314163151631616317163181631916320163211632216323163241632516326163271632816329163301633116332163331633416335163361633716338163391634016341163421634316344163451634616347163481634916350163511635216353163541635516356163571635816359163601636116362163631636416365163661636716368163691637016371163721637316374163751637616377163781637916380163811638216383163841638516386163871638816389163901639116392163931639416395163961639716398163991640016401164021640316404164051640616407164081640916410164111641216413164141641516416164171641816419164201642116422164231642416425164261642716428164291643016431164321643316434164351643616437164381643916440164411644216443164441644516446164471644816449164501645116452164531645416455164561645716458164591646016461164621646316464164651646616467164681646916470164711647216473164741647516476164771647816479164801648116482164831648416485164861648716488164891649016491164921649316494164951649616497164981649916500165011650216503165041650516506165071650816509165101651116512165131651416515165161651716518165191652016521165221652316524165251652616527165281652916530165311653216533165341653516536165371653816539165401654116542165431654416545165461654716548165491655016551165521655316554165551655616557165581655916560165611656216563165641656516566165671656816569165701657116572165731657416575165761657716578165791658016581165821658316584165851658616587165881658916590165911659216593165941659516596165971659816599166001660116602166031660416605166061660716608166091661016611166121661316614166151661616617166181661916620166211662216623166241662516626166271662816629166301663116632166331663416635166361663716638166391664016641166421664316644166451664616647166481664916650166511665216653166541665516656166571665816659166601666116662166631666416665166661666716668166691667016671166721667316674166751667616677166781667916680166811668216683166841668516686166871668816689166901669116692166931669416695166961669716698166991670016701167021670316704167051670616707167081670916710167111671216713167141671516716167171671816719167201672116722167231672416725167261672716728167291673016731167321673316734167351673616737167381673916740167411674216743167441674516746167471674816749167501675116752167531675416755167561675716758167591676016761167621676316764167651676616767167681676916770167711677216773167741677516776167771677816779167801678116782167831678416785167861678716788167891679016791167921679316794167951679616797167981679916800168011680216803168041680516806168071680816809168101681116812168131681416815168161681716818168191682016821168221682316824168251682616827168281682916830168311683216833168341683516836168371683816839168401684116842168431684416845168461684716848168491685016851168521685316854168551685616857168581685916860168611686216863168641686516866168671686816869168701687116872168731687416875168761687716878168791688016881168821688316884168851688616887168881688916890168911689216893168941689516896168971689816899169001690116902169031690416905169061690716908169091691016911169121691316914169151691616917169181691916920169211692216923169241692516926169271692816929169301693116932169331693416935169361693716938169391694016941169421694316944169451694616947169481694916950169511695216953169541695516956169571695816959169601696116962169631696416965169661696716968169691697016971169721697316974169751697616977169781697916980169811698216983169841698516986169871698816989169901699116992169931699416995169961699716998169991700017001170021700317004170051700617007170081700917010170111701217013170141701517016170171701817019170201702117022170231702417025170261702717028170291703017031170321703317034170351703617037170381703917040170411704217043170441704517046170471704817049170501705117052170531705417055170561705717058170591706017061170621706317064170651706617067170681706917070170711707217073170741707517076170771707817079170801708117082170831708417085170861708717088170891709017091170921709317094170951709617097170981709917100171011710217103171041710517106171071710817109171101711117112171131711417115171161711717118171191712017121171221712317124171251712617127171281712917130171311713217133171341713517136171371713817139171401714117142171431714417145171461714717148171491715017151171521715317154171551715617157171581715917160171611716217163171641716517166171671716817169171701717117172171731717417175171761717717178171791718017181171821718317184171851718617187171881718917190171911719217193171941719517196171971719817199172001720117202172031720417205172061720717208172091721017211172121721317214172151721617217172181721917220172211722217223172241722517226172271722817229172301723117232172331723417235172361723717238172391724017241172421724317244172451724617247172481724917250172511725217253172541725517256172571725817259172601726117262172631726417265172661726717268172691727017271172721727317274172751727617277172781727917280172811728217283172841728517286172871728817289172901729117292172931729417295172961729717298172991730017301173021730317304173051730617307173081730917310173111731217313173141731517316173171731817319173201732117322173231732417325173261732717328173291733017331173321733317334173351733617337173381733917340173411734217343173441734517346173471734817349173501735117352173531735417355173561735717358173591736017361173621736317364173651736617367173681736917370173711737217373173741737517376173771737817379173801738117382173831738417385173861738717388173891739017391173921739317394173951739617397173981739917400174011740217403174041740517406174071740817409174101741117412174131741417415174161741717418174191742017421174221742317424174251742617427174281742917430174311743217433174341743517436174371743817439174401744117442174431744417445174461744717448174491745017451174521745317454174551745617457174581745917460174611746217463174641746517466174671746817469174701747117472174731747417475174761747717478174791748017481174821748317484174851748617487174881748917490174911749217493174941749517496174971749817499175001750117502175031750417505175061750717508175091751017511175121751317514175151751617517175181751917520175211752217523175241752517526175271752817529175301753117532175331753417535175361753717538175391754017541175421754317544175451754617547175481754917550175511755217553175541755517556175571755817559175601756117562175631756417565175661756717568175691757017571175721757317574175751757617577175781757917580175811758217583175841758517586175871758817589175901759117592175931759417595175961759717598175991760017601176021760317604176051760617607176081760917610176111761217613176141761517616176171761817619176201762117622176231762417625176261762717628176291763017631176321763317634176351763617637176381763917640176411764217643176441764517646176471764817649176501765117652176531765417655176561765717658176591766017661176621766317664176651766617667176681766917670176711767217673176741767517676176771767817679176801768117682176831768417685176861768717688176891769017691176921769317694176951769617697176981769917700177011770217703177041770517706177071770817709177101771117712177131771417715177161771717718177191772017721177221772317724177251772617727177281772917730177311773217733177341773517736177371773817739177401774117742177431774417745177461774717748177491775017751177521775317754177551775617757177581775917760177611776217763177641776517766177671776817769177701777117772177731777417775177761777717778177791778017781177821778317784177851778617787177881778917790177911779217793177941779517796177971779817799178001780117802178031780417805178061780717808178091781017811178121781317814178151781617817178181781917820178211782217823178241782517826178271782817829178301783117832178331783417835178361783717838178391784017841178421784317844178451784617847178481784917850178511785217853178541785517856178571785817859178601786117862178631786417865178661786717868178691787017871178721787317874178751787617877178781787917880178811788217883178841788517886178871788817889178901789117892178931789417895178961789717898178991790017901179021790317904179051790617907179081790917910179111791217913179141791517916179171791817919179201792117922179231792417925179261792717928179291793017931179321793317934179351793617937179381793917940179411794217943179441794517946179471794817949179501795117952179531795417955179561795717958179591796017961179621796317964179651796617967179681796917970179711797217973179741797517976179771797817979179801798117982179831798417985179861798717988179891799017991179921799317994179951799617997179981799918000180011800218003180041800518006180071800818009180101801118012180131801418015180161801718018180191802018021180221802318024180251802618027180281802918030180311803218033180341803518036180371803818039180401804118042180431804418045180461804718048180491805018051180521805318054180551805618057180581805918060180611806218063180641806518066180671806818069180701807118072180731807418075180761807718078180791808018081180821808318084180851808618087180881808918090180911809218093180941809518096180971809818099181001810118102181031810418105181061810718108181091811018111181121811318114181151811618117181181811918120181211812218123181241812518126181271812818129181301813118132181331813418135181361813718138181391814018141181421814318144181451814618147181481814918150181511815218153181541815518156181571815818159181601816118162181631816418165181661816718168181691817018171181721817318174181751817618177181781817918180181811818218183181841818518186181871818818189181901819118192181931819418195181961819718198181991820018201182021820318204182051820618207182081820918210182111821218213182141821518216182171821818219182201822118222182231822418225182261822718228182291823018231182321823318234182351823618237182381823918240182411824218243182441824518246182471824818249182501825118252182531825418255182561825718258182591826018261182621826318264182651826618267182681826918270182711827218273182741827518276182771827818279182801828118282182831828418285182861828718288182891829018291182921829318294182951829618297182981829918300183011830218303183041830518306183071830818309183101831118312183131831418315183161831718318183191832018321183221832318324183251832618327183281832918330183311833218333183341833518336183371833818339183401834118342183431834418345183461834718348183491835018351183521835318354183551835618357183581835918360183611836218363183641836518366183671836818369183701837118372183731837418375183761837718378183791838018381183821838318384183851838618387183881838918390183911839218393183941839518396183971839818399184001840118402184031840418405184061840718408184091841018411184121841318414184151841618417184181841918420184211842218423184241842518426184271842818429184301843118432184331843418435184361843718438184391844018441184421844318444184451844618447184481844918450184511845218453184541845518456184571845818459184601846118462184631846418465184661846718468184691847018471184721847318474184751847618477184781847918480184811848218483184841848518486184871848818489184901849118492184931849418495184961849718498184991850018501185021850318504185051850618507185081850918510185111851218513185141851518516185171851818519185201852118522185231852418525185261852718528185291853018531185321853318534185351853618537185381853918540185411854218543185441854518546185471854818549185501855118552185531855418555185561855718558185591856018561185621856318564185651856618567185681856918570185711857218573185741857518576185771857818579185801858118582185831858418585185861858718588185891859018591185921859318594185951859618597185981859918600186011860218603186041860518606186071860818609186101861118612186131861418615186161861718618186191862018621186221862318624186251862618627186281862918630186311863218633186341863518636186371863818639186401864118642186431864418645186461864718648186491865018651186521865318654186551865618657186581865918660186611866218663186641866518666186671866818669186701867118672186731867418675186761867718678186791868018681186821868318684186851868618687186881868918690186911869218693186941869518696186971869818699187001870118702187031870418705187061870718708187091871018711187121871318714187151871618717187181871918720187211872218723187241872518726187271872818729187301873118732187331873418735187361873718738187391874018741187421874318744187451874618747187481874918750187511875218753187541875518756187571875818759187601876118762187631876418765187661876718768187691877018771187721877318774187751877618777187781877918780187811878218783187841878518786187871878818789187901879118792187931879418795187961879718798187991880018801188021880318804188051880618807188081880918810188111881218813188141881518816188171881818819188201882118822188231882418825188261882718828188291883018831188321883318834188351883618837188381883918840188411884218843188441884518846188471884818849188501885118852188531885418855188561885718858188591886018861188621886318864188651886618867188681886918870188711887218873188741887518876188771887818879188801888118882188831888418885188861888718888188891889018891188921889318894188951889618897188981889918900189011890218903189041890518906189071890818909189101891118912189131891418915189161891718918189191892018921189221892318924189251892618927189281892918930189311893218933189341893518936189371893818939189401894118942189431894418945189461894718948189491895018951189521895318954189551895618957189581895918960189611896218963189641896518966189671896818969189701897118972189731897418975189761897718978189791898018981189821898318984189851898618987189881898918990189911899218993189941899518996189971899818999190001900119002190031900419005190061900719008190091901019011190121901319014190151901619017190181901919020190211902219023190241902519026190271902819029190301903119032190331903419035190361903719038190391904019041190421904319044190451904619047190481904919050190511905219053190541905519056190571905819059190601906119062190631906419065190661906719068190691907019071190721907319074190751907619077190781907919080190811908219083190841908519086190871908819089190901909119092190931909419095190961909719098190991910019101191021910319104191051910619107191081910919110191111911219113191141911519116191171911819119191201912119122191231912419125191261912719128191291913019131191321913319134191351913619137191381913919140191411914219143191441914519146191471914819149191501915119152191531915419155191561915719158191591916019161191621916319164191651916619167191681916919170191711917219173191741917519176191771917819179191801918119182191831918419185191861918719188191891919019191191921919319194191951919619197191981919919200192011920219203192041920519206192071920819209192101921119212192131921419215192161921719218192191922019221192221922319224192251922619227192281922919230192311923219233192341923519236192371923819239192401924119242192431924419245192461924719248192491925019251192521925319254192551925619257192581925919260192611926219263192641926519266192671926819269192701927119272192731927419275192761927719278192791928019281192821928319284192851928619287192881928919290192911929219293192941929519296192971929819299193001930119302193031930419305193061930719308193091931019311193121931319314193151931619317193181931919320193211932219323193241932519326193271932819329193301933119332193331933419335193361933719338193391934019341193421934319344193451934619347193481934919350193511935219353193541935519356193571935819359193601936119362193631936419365193661936719368193691937019371193721937319374193751937619377193781937919380193811938219383193841938519386193871938819389193901939119392193931939419395193961939719398193991940019401194021940319404194051940619407194081940919410194111941219413194141941519416194171941819419194201942119422194231942419425194261942719428194291943019431194321943319434194351943619437194381943919440194411944219443194441944519446194471944819449194501945119452194531945419455194561945719458194591946019461194621946319464194651946619467194681946919470194711947219473194741947519476194771947819479194801948119482194831948419485194861948719488194891949019491194921949319494194951949619497194981949919500195011950219503195041950519506195071950819509195101951119512195131951419515195161951719518195191952019521195221952319524195251952619527195281952919530195311953219533195341953519536195371953819539195401954119542195431954419545195461954719548195491955019551195521955319554195551955619557195581955919560195611956219563195641956519566195671956819569195701957119572195731957419575195761957719578195791958019581195821958319584195851958619587195881958919590195911959219593195941959519596195971959819599196001960119602196031960419605196061960719608196091961019611196121961319614196151961619617196181961919620196211962219623196241962519626196271962819629196301963119632196331963419635196361963719638196391964019641196421964319644196451964619647196481964919650196511965219653196541965519656196571965819659196601966119662196631966419665196661966719668196691967019671196721967319674196751967619677196781967919680196811968219683196841968519686196871968819689196901969119692196931969419695196961969719698196991970019701197021970319704197051970619707197081970919710197111971219713197141971519716197171971819719197201972119722197231972419725197261972719728197291973019731197321973319734197351973619737197381973919740197411974219743197441974519746197471974819749197501975119752197531975419755197561975719758197591976019761197621976319764197651976619767197681976919770197711977219773197741977519776197771977819779197801978119782197831978419785197861978719788197891979019791197921979319794197951979619797197981979919800198011980219803198041980519806198071980819809198101981119812198131981419815198161981719818198191982019821198221982319824198251982619827198281982919830198311983219833198341983519836198371983819839198401984119842198431984419845198461984719848198491985019851198521985319854198551985619857198581985919860198611986219863198641986519866198671986819869198701987119872198731987419875198761987719878198791988019881198821988319884198851988619887198881988919890198911989219893198941989519896198971989819899199001990119902199031990419905199061990719908199091991019911199121991319914199151991619917199181991919920199211992219923199241992519926199271992819929199301993119932199331993419935199361993719938199391994019941199421994319944199451994619947199481994919950199511995219953199541995519956199571995819959199601996119962199631996419965199661996719968199691997019971199721997319974199751997619977199781997919980199811998219983199841998519986199871998819989199901999119992199931999419995199961999719998199992000020001200022000320004200052000620007200082000920010200112001220013200142001520016200172001820019200202002120022200232002420025200262002720028200292003020031200322003320034200352003620037200382003920040200412004220043200442004520046200472004820049200502005120052200532005420055200562005720058200592006020061200622006320064200652006620067200682006920070200712007220073200742007520076200772007820079200802008120082200832008420085200862008720088200892009020091200922009320094200952009620097200982009920100201012010220103201042010520106201072010820109201102011120112201132011420115201162011720118201192012020121201222012320124201252012620127201282012920130201312013220133201342013520136201372013820139201402014120142201432014420145201462014720148201492015020151201522015320154201552015620157201582015920160201612016220163201642016520166201672016820169201702017120172201732017420175201762017720178201792018020181201822018320184201852018620187201882018920190201912019220193201942019520196201972019820199202002020120202202032020420205202062020720208202092021020211202122021320214202152021620217202182021920220202212022220223202242022520226202272022820229202302023120232202332023420235202362023720238202392024020241202422024320244202452024620247202482024920250202512025220253202542025520256202572025820259202602026120262202632026420265202662026720268202692027020271202722027320274202752027620277202782027920280202812028220283202842028520286202872028820289202902029120292202932029420295202962029720298202992030020301203022030320304203052030620307203082030920310203112031220313203142031520316203172031820319203202032120322203232032420325203262032720328203292033020331203322033320334203352033620337203382033920340203412034220343203442034520346203472034820349203502035120352203532035420355203562035720358203592036020361203622036320364203652036620367203682036920370203712037220373203742037520376203772037820379203802038120382203832038420385203862038720388203892039020391203922039320394203952039620397203982039920400204012040220403204042040520406204072040820409204102041120412204132041420415204162041720418204192042020421204222042320424204252042620427204282042920430204312043220433204342043520436204372043820439204402044120442204432044420445204462044720448204492045020451204522045320454204552045620457204582045920460204612046220463204642046520466204672046820469204702047120472204732047420475204762047720478204792048020481204822048320484204852048620487204882048920490204912049220493204942049520496204972049820499205002050120502205032050420505205062050720508205092051020511205122051320514205152051620517205182051920520205212052220523205242052520526205272052820529205302053120532205332053420535205362053720538205392054020541205422054320544205452054620547205482054920550205512055220553205542055520556205572055820559205602056120562205632056420565205662056720568205692057020571205722057320574205752057620577205782057920580205812058220583205842058520586205872058820589205902059120592205932059420595205962059720598205992060020601206022060320604206052060620607206082060920610206112061220613206142061520616206172061820619206202062120622206232062420625206262062720628206292063020631206322063320634206352063620637206382063920640206412064220643206442064520646206472064820649206502065120652206532065420655206562065720658206592066020661206622066320664206652066620667206682066920670206712067220673206742067520676206772067820679206802068120682206832068420685206862068720688206892069020691206922069320694206952069620697206982069920700207012070220703207042070520706207072070820709207102071120712207132071420715207162071720718207192072020721207222072320724207252072620727207282072920730207312073220733207342073520736207372073820739207402074120742207432074420745207462074720748207492075020751207522075320754207552075620757207582075920760207612076220763207642076520766207672076820769207702077120772207732077420775207762077720778207792078020781207822078320784207852078620787207882078920790207912079220793207942079520796207972079820799208002080120802208032080420805208062080720808208092081020811208122081320814208152081620817208182081920820208212082220823208242082520826208272082820829208302083120832208332083420835208362083720838208392084020841208422084320844208452084620847208482084920850208512085220853208542085520856208572085820859208602086120862208632086420865208662086720868208692087020871208722087320874208752087620877208782087920880208812088220883208842088520886208872088820889208902089120892208932089420895208962089720898208992090020901209022090320904209052090620907209082090920910209112091220913209142091520916209172091820919209202092120922209232092420925209262092720928209292093020931209322093320934209352093620937209382093920940209412094220943209442094520946209472094820949209502095120952209532095420955209562095720958209592096020961209622096320964209652096620967209682096920970209712097220973209742097520976209772097820979209802098120982209832098420985209862098720988209892099020991209922099320994209952099620997209982099921000210012100221003210042100521006210072100821009210102101121012210132101421015210162101721018210192102021021210222102321024210252102621027210282102921030210312103221033210342103521036210372103821039210402104121042210432104421045210462104721048210492105021051210522105321054210552105621057210582105921060210612106221063210642106521066210672106821069210702107121072210732107421075210762107721078210792108021081210822108321084210852108621087210882108921090210912109221093210942109521096210972109821099211002110121102211032110421105211062110721108211092111021111211122111321114211152111621117211182111921120211212112221123211242112521126211272112821129211302113121132211332113421135211362113721138211392114021141211422114321144211452114621147211482114921150211512115221153211542115521156211572115821159211602116121162211632116421165211662116721168211692117021171211722117321174211752117621177211782117921180211812118221183211842118521186211872118821189211902119121192211932119421195211962119721198211992120021201212022120321204212052120621207212082120921210212112121221213212142121521216212172121821219212202122121222212232122421225212262122721228212292123021231212322123321234212352123621237212382123921240212412124221243212442124521246212472124821249212502125121252212532125421255212562125721258212592126021261212622126321264212652126621267212682126921270212712127221273212742127521276212772127821279212802128121282212832128421285212862128721288212892129021291212922129321294212952129621297212982129921300213012130221303213042130521306213072130821309213102131121312213132131421315213162131721318213192132021321213222132321324213252132621327213282132921330213312133221333213342133521336213372133821339213402134121342213432134421345213462134721348213492135021351213522135321354213552135621357213582135921360213612136221363213642136521366213672136821369213702137121372213732137421375213762137721378213792138021381213822138321384213852138621387213882138921390213912139221393213942139521396213972139821399214002140121402214032140421405214062140721408214092141021411214122141321414214152141621417214182141921420214212142221423214242142521426214272142821429214302143121432214332143421435214362143721438214392144021441214422144321444214452144621447214482144921450214512145221453214542145521456214572145821459214602146121462214632146421465214662146721468214692147021471214722147321474214752147621477214782147921480214812148221483214842148521486214872148821489214902149121492214932149421495214962149721498214992150021501215022150321504215052150621507215082150921510215112151221513215142151521516215172151821519215202152121522215232152421525215262152721528215292153021531215322153321534215352153621537215382153921540215412154221543215442154521546215472154821549215502155121552215532155421555215562155721558215592156021561215622156321564215652156621567215682156921570215712157221573215742157521576215772157821579215802158121582215832158421585215862158721588215892159021591215922159321594215952159621597215982159921600216012160221603216042160521606216072160821609216102161121612216132161421615216162161721618216192162021621216222162321624216252162621627216282162921630216312163221633216342163521636216372163821639216402164121642216432164421645216462164721648216492165021651216522165321654216552165621657216582165921660216612166221663216642166521666216672166821669216702167121672216732167421675216762167721678216792168021681216822168321684216852168621687216882168921690216912169221693216942169521696216972169821699217002170121702217032170421705217062170721708217092171021711217122171321714217152171621717217182171921720217212172221723217242172521726217272172821729217302173121732217332173421735217362173721738217392174021741217422174321744217452174621747217482174921750217512175221753217542175521756217572175821759217602176121762217632176421765217662176721768217692177021771217722177321774217752177621777217782177921780217812178221783217842178521786217872178821789217902179121792217932179421795217962179721798217992180021801218022180321804218052180621807218082180921810218112181221813218142181521816218172181821819218202182121822218232182421825218262182721828218292183021831218322183321834218352183621837218382183921840218412184221843218442184521846218472184821849218502185121852218532185421855218562185721858218592186021861218622186321864218652186621867218682186921870218712187221873218742187521876218772187821879218802188121882218832188421885218862188721888218892189021891218922189321894218952189621897218982189921900219012190221903219042190521906219072190821909219102191121912219132191421915219162191721918219192192021921219222192321924219252192621927219282192921930219312193221933219342193521936219372193821939219402194121942219432194421945219462194721948219492195021951219522195321954219552195621957219582195921960219612196221963219642196521966219672196821969219702197121972219732197421975219762197721978219792198021981219822198321984219852198621987219882198921990219912199221993219942199521996219972199821999220002200122002220032200422005220062200722008220092201022011220122201322014220152201622017220182201922020220212202222023220242202522026220272202822029220302203122032220332203422035220362203722038220392204022041220422204322044220452204622047220482204922050220512205222053220542205522056220572205822059220602206122062220632206422065220662206722068220692207022071220722207322074220752207622077220782207922080220812208222083220842208522086220872208822089220902209122092220932209422095220962209722098220992210022101221022210322104221052210622107221082210922110221112211222113221142211522116221172211822119221202212122122221232212422125221262212722128221292213022131221322213322134221352213622137221382213922140221412214222143221442214522146221472214822149221502215122152221532215422155221562215722158221592216022161221622216322164221652216622167221682216922170221712217222173221742217522176221772217822179221802218122182221832218422185221862218722188221892219022191221922219322194221952219622197221982219922200222012220222203222042220522206222072220822209222102221122212222132221422215222162221722218222192222022221222222222322224222252222622227222282222922230222312223222233222342223522236222372223822239222402224122242222432224422245222462224722248222492225022251222522225322254222552225622257222582225922260222612226222263222642226522266222672226822269222702227122272222732227422275222762227722278222792228022281222822228322284222852228622287222882228922290222912229222293222942229522296222972229822299223002230122302223032230422305223062230722308223092231022311223122231322314223152231622317223182231922320223212232222323223242232522326223272232822329223302233122332223332233422335223362233722338223392234022341223422234322344223452234622347223482234922350223512235222353223542235522356223572235822359223602236122362223632236422365223662236722368223692237022371223722237322374223752237622377223782237922380223812238222383223842238522386223872238822389223902239122392223932239422395223962239722398223992240022401224022240322404224052240622407224082240922410224112241222413224142241522416224172241822419224202242122422224232242422425224262242722428224292243022431224322243322434224352243622437224382243922440224412244222443224442244522446224472244822449224502245122452224532245422455224562245722458224592246022461224622246322464224652246622467224682246922470224712247222473224742247522476224772247822479224802248122482224832248422485224862248722488224892249022491224922249322494224952249622497224982249922500225012250222503225042250522506225072250822509225102251122512225132251422515225162251722518225192252022521225222252322524225252252622527225282252922530225312253222533225342253522536225372253822539225402254122542225432254422545225462254722548225492255022551225522255322554225552255622557225582255922560225612256222563225642256522566225672256822569225702257122572225732257422575225762257722578225792258022581225822258322584225852258622587225882258922590225912259222593225942259522596225972259822599226002260122602226032260422605226062260722608226092261022611226122261322614226152261622617226182261922620226212262222623226242262522626226272262822629226302263122632226332263422635226362263722638226392264022641226422264322644226452264622647226482264922650226512265222653226542265522656226572265822659226602266122662226632266422665226662266722668226692267022671226722267322674226752267622677226782267922680226812268222683226842268522686226872268822689226902269122692226932269422695226962269722698226992270022701227022270322704227052270622707227082270922710227112271222713227142271522716227172271822719227202272122722227232272422725227262272722728227292273022731227322273322734227352273622737227382273922740227412274222743227442274522746227472274822749227502275122752227532275422755227562275722758227592276022761227622276322764227652276622767227682276922770227712277222773227742277522776227772277822779227802278122782227832278422785227862278722788227892279022791227922279322794227952279622797227982279922800228012280222803228042280522806228072280822809228102281122812228132281422815228162281722818228192282022821228222282322824228252282622827228282282922830228312283222833228342283522836228372283822839228402284122842228432284422845228462284722848228492285022851228522285322854228552285622857228582285922860228612286222863228642286522866228672286822869228702287122872228732287422875228762287722878228792288022881228822288322884228852288622887228882288922890228912289222893228942289522896228972289822899229002290122902229032290422905229062290722908229092291022911229122291322914229152291622917229182291922920229212292222923229242292522926229272292822929229302293122932229332293422935229362293722938229392294022941229422294322944229452294622947229482294922950229512295222953229542295522956229572295822959229602296122962229632296422965229662296722968229692297022971229722297322974229752297622977229782297922980229812298222983229842298522986229872298822989229902299122992229932299422995229962299722998229992300023001230022300323004230052300623007230082300923010230112301223013230142301523016230172301823019230202302123022230232302423025230262302723028230292303023031230322303323034230352303623037230382303923040230412304223043230442304523046230472304823049230502305123052230532305423055230562305723058230592306023061230622306323064230652306623067230682306923070230712307223073230742307523076230772307823079230802308123082230832308423085230862308723088230892309023091230922309323094230952309623097230982309923100231012310223103231042310523106231072310823109231102311123112231132311423115231162311723118231192312023121231222312323124231252312623127231282312923130231312313223133231342313523136231372313823139231402314123142231432314423145231462314723148231492315023151231522315323154231552315623157231582315923160231612316223163231642316523166231672316823169231702317123172231732317423175231762317723178231792318023181231822318323184231852318623187231882318923190231912319223193231942319523196231972319823199232002320123202232032320423205232062320723208232092321023211232122321323214232152321623217232182321923220232212322223223232242322523226232272322823229232302323123232232332323423235232362323723238232392324023241232422324323244232452324623247232482324923250232512325223253232542325523256232572325823259232602326123262232632326423265232662326723268232692327023271232722327323274232752327623277232782327923280232812328223283232842328523286232872328823289232902329123292232932329423295232962329723298232992330023301233022330323304233052330623307233082330923310233112331223313233142331523316233172331823319233202332123322233232332423325233262332723328233292333023331233322333323334233352333623337233382333923340233412334223343233442334523346233472334823349233502335123352233532335423355233562335723358233592336023361233622336323364233652336623367233682336923370233712337223373233742337523376233772337823379233802338123382233832338423385233862338723388233892339023391233922339323394233952339623397233982339923400234012340223403234042340523406234072340823409234102341123412234132341423415234162341723418234192342023421234222342323424234252342623427234282342923430234312343223433234342343523436234372343823439234402344123442234432344423445234462344723448234492345023451234522345323454234552345623457234582345923460234612346223463234642346523466234672346823469234702347123472234732347423475234762347723478234792348023481234822348323484234852348623487234882348923490234912349223493234942349523496234972349823499235002350123502235032350423505235062350723508235092351023511235122351323514235152351623517235182351923520235212352223523235242352523526235272352823529235302353123532235332353423535235362353723538235392354023541235422354323544235452354623547235482354923550235512355223553235542355523556235572355823559235602356123562235632356423565235662356723568235692357023571235722357323574235752357623577235782357923580235812358223583235842358523586235872358823589235902359123592235932359423595235962359723598235992360023601236022360323604236052360623607236082360923610236112361223613236142361523616236172361823619236202362123622236232362423625236262362723628236292363023631236322363323634236352363623637236382363923640236412364223643236442364523646236472364823649236502365123652236532365423655236562365723658236592366023661236622366323664236652366623667236682366923670236712367223673236742367523676236772367823679236802368123682236832368423685236862368723688236892369023691236922369323694236952369623697236982369923700237012370223703237042370523706237072370823709237102371123712237132371423715237162371723718237192372023721237222372323724237252372623727237282372923730237312373223733237342373523736237372373823739237402374123742237432374423745237462374723748237492375023751237522375323754237552375623757237582375923760237612376223763237642376523766237672376823769237702377123772237732377423775237762377723778237792378023781237822378323784237852378623787237882378923790237912379223793237942379523796237972379823799238002380123802238032380423805238062380723808238092381023811238122381323814238152381623817238182381923820238212382223823238242382523826238272382823829238302383123832238332383423835238362383723838238392384023841238422384323844238452384623847238482384923850238512385223853238542385523856238572385823859238602386123862238632386423865238662386723868238692387023871238722387323874238752387623877238782387923880238812388223883238842388523886238872388823889238902389123892238932389423895238962389723898238992390023901239022390323904239052390623907239082390923910239112391223913239142391523916239172391823919239202392123922239232392423925239262392723928239292393023931239322393323934239352393623937239382393923940239412394223943239442394523946239472394823949239502395123952239532395423955239562395723958239592396023961239622396323964239652396623967239682396923970239712397223973239742397523976239772397823979239802398123982239832398423985239862398723988239892399023991239922399323994239952399623997239982399924000240012400224003240042400524006240072400824009240102401124012240132401424015240162401724018240192402024021240222402324024240252402624027240282402924030240312403224033240342403524036240372403824039240402404124042240432404424045240462404724048240492405024051240522405324054240552405624057240582405924060240612406224063240642406524066240672406824069240702407124072240732407424075240762407724078240792408024081240822408324084240852408624087240882408924090240912409224093240942409524096240972409824099241002410124102241032410424105241062410724108241092411024111241122411324114241152411624117241182411924120241212412224123241242412524126241272412824129241302413124132241332413424135241362413724138241392414024141241422414324144241452414624147241482414924150241512415224153241542415524156241572415824159241602416124162241632416424165241662416724168241692417024171241722417324174241752417624177241782417924180241812418224183241842418524186241872418824189241902419124192241932419424195241962419724198241992420024201242022420324204242052420624207242082420924210242112421224213242142421524216242172421824219242202422124222242232422424225242262422724228242292423024231242322423324234242352423624237242382423924240242412424224243242442424524246242472424824249242502425124252242532425424255242562425724258242592426024261242622426324264242652426624267242682426924270242712427224273242742427524276242772427824279242802428124282242832428424285242862428724288242892429024291242922429324294242952429624297242982429924300243012430224303243042430524306243072430824309243102431124312243132431424315243162431724318243192432024321243222432324324243252432624327243282432924330243312433224333243342433524336243372433824339243402434124342243432434424345243462434724348243492435024351243522435324354243552435624357243582435924360243612436224363243642436524366243672436824369243702437124372243732437424375243762437724378243792438024381243822438324384243852438624387243882438924390243912439224393243942439524396243972439824399244002440124402244032440424405244062440724408244092441024411244122441324414244152441624417244182441924420244212442224423244242442524426244272442824429244302443124432244332443424435244362443724438244392444024441244422444324444244452444624447244482444924450244512445224453244542445524456244572445824459244602446124462244632446424465244662446724468244692447024471244722447324474244752447624477244782447924480244812448224483244842448524486244872448824489244902449124492244932449424495244962449724498244992450024501245022450324504245052450624507245082450924510245112451224513245142451524516245172451824519245202452124522245232452424525245262452724528245292453024531245322453324534245352453624537245382453924540245412454224543245442454524546245472454824549245502455124552245532455424555245562455724558245592456024561245622456324564245652456624567245682456924570245712457224573245742457524576245772457824579245802458124582245832458424585245862458724588245892459024591245922459324594245952459624597245982459924600246012460224603246042460524606246072460824609246102461124612246132461424615246162461724618246192462024621246222462324624246252462624627246282462924630246312463224633246342463524636246372463824639246402464124642246432464424645246462464724648246492465024651246522465324654246552465624657246582465924660246612466224663246642466524666246672466824669246702467124672246732467424675246762467724678246792468024681246822468324684246852468624687246882468924690246912469224693246942469524696246972469824699247002470124702247032470424705247062470724708247092471024711247122471324714247152471624717247182471924720247212472224723247242472524726247272472824729247302473124732247332473424735247362473724738247392474024741247422474324744247452474624747247482474924750247512475224753247542475524756247572475824759247602476124762247632476424765247662476724768247692477024771247722477324774247752477624777247782477924780247812478224783247842478524786247872478824789247902479124792247932479424795247962479724798247992480024801248022480324804248052480624807248082480924810248112481224813248142481524816248172481824819248202482124822248232482424825248262482724828248292483024831248322483324834248352483624837248382483924840248412484224843248442484524846248472484824849248502485124852248532485424855248562485724858248592486024861248622486324864248652486624867248682486924870248712487224873248742487524876248772487824879248802488124882248832488424885248862488724888248892489024891248922489324894248952489624897248982489924900249012490224903249042490524906249072490824909249102491124912249132491424915249162491724918249192492024921249222492324924249252492624927249282492924930249312493224933249342493524936249372493824939249402494124942249432494424945249462494724948249492495024951249522495324954249552495624957249582495924960249612496224963249642496524966249672496824969249702497124972249732497424975249762497724978249792498024981249822498324984249852498624987249882498924990249912499224993249942499524996249972499824999250002500125002250032500425005250062500725008250092501025011250122501325014250152501625017250182501925020250212502225023250242502525026250272502825029250302503125032250332503425035250362503725038250392504025041250422504325044250452504625047250482504925050250512505225053250542505525056250572505825059250602506125062250632506425065250662506725068250692507025071250722507325074250752507625077250782507925080250812508225083250842508525086250872508825089250902509125092250932509425095250962509725098250992510025101251022510325104251052510625107251082510925110251112511225113251142511525116251172511825119251202512125122251232512425125251262512725128251292513025131251322513325134251352513625137251382513925140251412514225143251442514525146251472514825149251502515125152251532515425155251562515725158251592516025161251622516325164251652516625167251682516925170251712517225173251742517525176251772517825179251802518125182251832518425185251862518725188251892519025191251922519325194251952519625197251982519925200252012520225203252042520525206252072520825209252102521125212252132521425215252162521725218252192522025221252222522325224252252522625227252282522925230252312523225233252342523525236252372523825239252402524125242252432524425245252462524725248252492525025251252522525325254252552525625257252582525925260252612526225263252642526525266252672526825269252702527125272252732527425275252762527725278252792528025281252822528325284252852528625287252882528925290252912529225293252942529525296252972529825299253002530125302253032530425305253062530725308253092531025311253122531325314253152531625317253182531925320253212532225323253242532525326253272532825329253302533125332253332533425335253362533725338253392534025341253422534325344253452534625347253482534925350253512535225353253542535525356253572535825359253602536125362253632536425365253662536725368253692537025371253722537325374253752537625377253782537925380253812538225383253842538525386253872538825389253902539125392253932539425395253962539725398253992540025401254022540325404254052540625407254082540925410254112541225413254142541525416254172541825419254202542125422254232542425425254262542725428254292543025431254322543325434254352543625437254382543925440254412544225443254442544525446254472544825449254502545125452254532545425455254562545725458254592546025461254622546325464254652546625467254682546925470254712547225473254742547525476254772547825479254802548125482254832548425485254862548725488254892549025491254922549325494254952549625497254982549925500255012550225503255042550525506255072550825509255102551125512255132551425515255162551725518255192552025521255222552325524255252552625527255282552925530255312553225533255342553525536255372553825539255402554125542255432554425545255462554725548255492555025551255522555325554255552555625557255582555925560255612556225563255642556525566255672556825569255702557125572255732557425575255762557725578255792558025581255822558325584255852558625587255882558925590255912559225593255942559525596255972559825599256002560125602256032560425605256062560725608256092561025611256122561325614256152561625617256182561925620256212562225623256242562525626256272562825629256302563125632256332563425635256362563725638256392564025641256422564325644256452564625647256482564925650256512565225653256542565525656256572565825659256602566125662256632566425665256662566725668256692567025671256722567325674256752567625677256782567925680256812568225683256842568525686256872568825689256902569125692256932569425695256962569725698256992570025701257022570325704257052570625707257082570925710257112571225713257142571525716257172571825719257202572125722257232572425725257262572725728257292573025731257322573325734257352573625737257382573925740257412574225743257442574525746257472574825749257502575125752257532575425755257562575725758257592576025761257622576325764257652576625767257682576925770257712577225773257742577525776257772577825779257802578125782257832578425785257862578725788257892579025791257922579325794257952579625797257982579925800258012580225803258042580525806258072580825809258102581125812258132581425815258162581725818258192582025821258222582325824258252582625827258282582925830258312583225833258342583525836258372583825839258402584125842258432584425845258462584725848258492585025851258522585325854258552585625857258582585925860258612586225863258642586525866258672586825869258702587125872258732587425875258762587725878258792588025881258822588325884258852588625887258882588925890258912589225893258942589525896258972589825899259002590125902259032590425905259062590725908259092591025911259122591325914259152591625917259182591925920259212592225923259242592525926259272592825929259302593125932259332593425935259362593725938259392594025941259422594325944259452594625947259482594925950259512595225953259542595525956259572595825959259602596125962259632596425965259662596725968259692597025971259722597325974259752597625977259782597925980259812598225983259842598525986259872598825989259902599125992259932599425995259962599725998259992600026001260022600326004260052600626007260082600926010260112601226013260142601526016260172601826019260202602126022260232602426025260262602726028260292603026031260322603326034260352603626037260382603926040260412604226043260442604526046260472604826049260502605126052260532605426055260562605726058260592606026061260622606326064260652606626067260682606926070260712607226073260742607526076260772607826079260802608126082260832608426085260862608726088260892609026091260922609326094260952609626097260982609926100261012610226103261042610526106261072610826109261102611126112261132611426115261162611726118261192612026121261222612326124261252612626127261282612926130261312613226133261342613526136261372613826139261402614126142261432614426145261462614726148261492615026151261522615326154261552615626157261582615926160261612616226163261642616526166261672616826169261702617126172261732617426175261762617726178261792618026181261822618326184261852618626187261882618926190261912619226193261942619526196261972619826199262002620126202262032620426205262062620726208262092621026211262122621326214262152621626217262182621926220262212622226223262242622526226262272622826229262302623126232262332623426235262362623726238262392624026241262422624326244262452624626247262482624926250262512625226253262542625526256262572625826259262602626126262262632626426265262662626726268262692627026271262722627326274262752627626277262782627926280262812628226283262842628526286262872628826289262902629126292262932629426295262962629726298262992630026301263022630326304263052630626307263082630926310263112631226313263142631526316263172631826319263202632126322263232632426325263262632726328263292633026331263322633326334263352633626337263382633926340263412634226343263442634526346263472634826349263502635126352263532635426355263562635726358263592636026361263622636326364263652636626367263682636926370263712637226373263742637526376263772637826379263802638126382263832638426385263862638726388263892639026391263922639326394263952639626397263982639926400264012640226403264042640526406264072640826409264102641126412264132641426415264162641726418264192642026421264222642326424264252642626427264282642926430264312643226433264342643526436264372643826439264402644126442264432644426445264462644726448264492645026451264522645326454264552645626457264582645926460264612646226463264642646526466264672646826469264702647126472264732647426475264762647726478264792648026481264822648326484264852648626487264882648926490264912649226493264942649526496264972649826499265002650126502265032650426505265062650726508265092651026511265122651326514265152651626517265182651926520265212652226523265242652526526265272652826529265302653126532265332653426535265362653726538265392654026541265422654326544265452654626547265482654926550265512655226553265542655526556265572655826559265602656126562265632656426565265662656726568265692657026571265722657326574265752657626577265782657926580265812658226583265842658526586265872658826589265902659126592265932659426595265962659726598265992660026601266022660326604266052660626607266082660926610266112661226613266142661526616266172661826619266202662126622266232662426625266262662726628266292663026631266322663326634266352663626637266382663926640266412664226643266442664526646266472664826649266502665126652266532665426655266562665726658266592666026661266622666326664266652666626667266682666926670266712667226673266742667526676266772667826679266802668126682266832668426685266862668726688266892669026691266922669326694266952669626697266982669926700267012670226703267042670526706267072670826709267102671126712267132671426715267162671726718267192672026721267222672326724267252672626727267282672926730267312673226733267342673526736267372673826739267402674126742267432674426745267462674726748267492675026751267522675326754267552675626757267582675926760267612676226763267642676526766267672676826769267702677126772267732677426775267762677726778267792678026781267822678326784267852678626787267882678926790267912679226793267942679526796267972679826799268002680126802268032680426805268062680726808268092681026811268122681326814268152681626817268182681926820268212682226823268242682526826268272682826829268302683126832268332683426835268362683726838268392684026841268422684326844268452684626847268482684926850268512685226853268542685526856268572685826859268602686126862268632686426865268662686726868268692687026871268722687326874268752687626877268782687926880268812688226883268842688526886268872688826889268902689126892268932689426895268962689726898268992690026901269022690326904269052690626907269082690926910269112691226913269142691526916269172691826919269202692126922269232692426925269262692726928269292693026931269322693326934269352693626937269382693926940269412694226943269442694526946269472694826949269502695126952269532695426955269562695726958269592696026961269622696326964269652696626967269682696926970269712697226973269742697526976269772697826979269802698126982269832698426985269862698726988269892699026991269922699326994269952699626997269982699927000270012700227003270042700527006270072700827009270102701127012270132701427015270162701727018270192702027021270222702327024270252702627027270282702927030270312703227033270342703527036270372703827039270402704127042270432704427045270462704727048270492705027051270522705327054270552705627057270582705927060270612706227063270642706527066270672706827069270702707127072270732707427075270762707727078270792708027081270822708327084270852708627087270882708927090270912709227093270942709527096270972709827099271002710127102271032710427105271062710727108271092711027111271122711327114271152711627117271182711927120271212712227123271242712527126271272712827129271302713127132271332713427135271362713727138271392714027141271422714327144271452714627147271482714927150271512715227153271542715527156271572715827159271602716127162271632716427165271662716727168271692717027171271722717327174271752717627177271782717927180271812718227183271842718527186271872718827189271902719127192271932719427195271962719727198271992720027201272022720327204272052720627207272082720927210272112721227213272142721527216272172721827219272202722127222272232722427225272262722727228272292723027231272322723327234272352723627237272382723927240272412724227243272442724527246272472724827249272502725127252272532725427255272562725727258272592726027261272622726327264272652726627267272682726927270272712727227273272742727527276272772727827279272802728127282272832728427285272862728727288272892729027291272922729327294272952729627297272982729927300273012730227303273042730527306273072730827309273102731127312273132731427315273162731727318273192732027321273222732327324273252732627327273282732927330273312733227333273342733527336273372733827339273402734127342273432734427345273462734727348273492735027351273522735327354273552735627357273582735927360273612736227363273642736527366273672736827369273702737127372273732737427375273762737727378273792738027381273822738327384273852738627387273882738927390273912739227393273942739527396273972739827399274002740127402274032740427405274062740727408274092741027411274122741327414274152741627417274182741927420274212742227423274242742527426274272742827429274302743127432274332743427435274362743727438274392744027441274422744327444274452744627447274482744927450274512745227453274542745527456274572745827459274602746127462274632746427465274662746727468274692747027471274722747327474274752747627477274782747927480274812748227483274842748527486274872748827489274902749127492274932749427495274962749727498274992750027501275022750327504275052750627507275082750927510275112751227513275142751527516275172751827519275202752127522275232752427525275262752727528275292753027531275322753327534275352753627537275382753927540275412754227543275442754527546275472754827549275502755127552275532755427555275562755727558275592756027561275622756327564275652756627567275682756927570275712757227573275742757527576275772757827579275802758127582275832758427585275862758727588275892759027591275922759327594275952759627597275982759927600276012760227603276042760527606276072760827609276102761127612276132761427615276162761727618276192762027621276222762327624276252762627627276282762927630276312763227633276342763527636276372763827639276402764127642276432764427645276462764727648276492765027651276522765327654276552765627657276582765927660276612766227663276642766527666276672766827669276702767127672276732767427675276762767727678276792768027681276822768327684276852768627687276882768927690276912769227693276942769527696276972769827699277002770127702277032770427705277062770727708277092771027711277122771327714277152771627717277182771927720277212772227723277242772527726277272772827729277302773127732277332773427735277362773727738277392774027741277422774327744277452774627747277482774927750277512775227753277542775527756277572775827759277602776127762277632776427765277662776727768277692777027771277722777327774277752777627777277782777927780277812778227783277842778527786277872778827789277902779127792277932779427795277962779727798277992780027801278022780327804278052780627807278082780927810278112781227813278142781527816278172781827819278202782127822278232782427825278262782727828278292783027831278322783327834278352783627837278382783927840278412784227843278442784527846278472784827849278502785127852278532785427855278562785727858278592786027861278622786327864278652786627867278682786927870278712787227873278742787527876278772787827879278802788127882278832788427885278862788727888278892789027891278922789327894278952789627897278982789927900279012790227903279042790527906279072790827909279102791127912279132791427915279162791727918279192792027921279222792327924279252792627927279282792927930279312793227933279342793527936279372793827939279402794127942279432794427945279462794727948279492795027951279522795327954279552795627957279582795927960279612796227963279642796527966279672796827969279702797127972279732797427975279762797727978279792798027981279822798327984279852798627987279882798927990279912799227993279942799527996279972799827999280002800128002280032800428005280062800728008280092801028011280122801328014280152801628017280182801928020280212802228023280242802528026280272802828029280302803128032280332803428035280362803728038280392804028041280422804328044280452804628047280482804928050280512805228053280542805528056280572805828059280602806128062280632806428065280662806728068280692807028071280722807328074280752807628077280782807928080280812808228083280842808528086280872808828089280902809128092280932809428095280962809728098280992810028101281022810328104281052810628107281082810928110281112811228113281142811528116281172811828119281202812128122281232812428125281262812728128281292813028131281322813328134281352813628137281382813928140281412814228143281442814528146281472814828149281502815128152281532815428155281562815728158281592816028161281622816328164281652816628167281682816928170281712817228173281742817528176281772817828179281802818128182281832818428185281862818728188281892819028191281922819328194281952819628197281982819928200282012820228203282042820528206282072820828209282102821128212282132821428215282162821728218282192822028221282222822328224282252822628227282282822928230282312823228233282342823528236282372823828239282402824128242282432824428245282462824728248282492825028251282522825328254282552825628257282582825928260282612826228263282642826528266282672826828269282702827128272282732827428275282762827728278282792828028281282822828328284282852828628287282882828928290282912829228293282942829528296282972829828299283002830128302283032830428305283062830728308283092831028311283122831328314283152831628317283182831928320283212832228323283242832528326283272832828329283302833128332283332833428335283362833728338283392834028341283422834328344283452834628347283482834928350283512835228353283542835528356283572835828359283602836128362283632836428365283662836728368283692837028371283722837328374283752837628377283782837928380283812838228383283842838528386283872838828389283902839128392283932839428395283962839728398283992840028401284022840328404284052840628407284082840928410284112841228413284142841528416284172841828419284202842128422284232842428425284262842728428284292843028431284322843328434284352843628437284382843928440284412844228443284442844528446284472844828449284502845128452284532845428455284562845728458284592846028461284622846328464284652846628467284682846928470284712847228473284742847528476284772847828479284802848128482284832848428485284862848728488284892849028491284922849328494284952849628497284982849928500285012850228503285042850528506285072850828509285102851128512285132851428515285162851728518285192852028521285222852328524285252852628527285282852928530285312853228533285342853528536285372853828539285402854128542285432854428545285462854728548285492855028551285522855328554285552855628557285582855928560285612856228563285642856528566285672856828569285702857128572285732857428575285762857728578285792858028581285822858328584285852858628587285882858928590285912859228593285942859528596285972859828599286002860128602286032860428605286062860728608286092861028611286122861328614286152861628617286182861928620286212862228623286242862528626286272862828629286302863128632286332863428635286362863728638286392864028641286422864328644286452864628647286482864928650286512865228653286542865528656286572865828659286602866128662286632866428665286662866728668286692867028671286722867328674286752867628677286782867928680286812868228683286842868528686286872868828689286902869128692286932869428695286962869728698286992870028701287022870328704287052870628707287082870928710287112871228713287142871528716287172871828719287202872128722287232872428725287262872728728287292873028731287322873328734287352873628737287382873928740287412874228743287442874528746287472874828749287502875128752287532875428755287562875728758287592876028761287622876328764287652876628767287682876928770287712877228773287742877528776287772877828779287802878128782287832878428785287862878728788287892879028791287922879328794287952879628797287982879928800288012880228803288042880528806288072880828809288102881128812288132881428815288162881728818288192882028821288222882328824288252882628827288282882928830288312883228833288342883528836288372883828839288402884128842288432884428845288462884728848288492885028851288522885328854288552885628857288582885928860288612886228863288642886528866288672886828869288702887128872288732887428875288762887728878288792888028881288822888328884288852888628887288882888928890288912889228893288942889528896288972889828899289002890128902289032890428905289062890728908289092891028911289122891328914289152891628917289182891928920289212892228923289242892528926289272892828929289302893128932289332893428935289362893728938289392894028941289422894328944289452894628947289482894928950289512895228953289542895528956289572895828959289602896128962289632896428965289662896728968289692897028971289722897328974289752897628977289782897928980289812898228983289842898528986289872898828989289902899128992289932899428995289962899728998289992900029001290022900329004290052900629007290082900929010290112901229013290142901529016290172901829019290202902129022290232902429025290262902729028290292903029031290322903329034290352903629037290382903929040290412904229043290442904529046290472904829049290502905129052290532905429055290562905729058290592906029061290622906329064290652906629067290682906929070290712907229073290742907529076290772907829079290802908129082290832908429085290862908729088290892909029091290922909329094290952909629097290982909929100291012910229103291042910529106291072910829109291102911129112291132911429115291162911729118291192912029121291222912329124291252912629127291282912929130291312913229133291342913529136291372913829139291402914129142291432914429145291462914729148291492915029151291522915329154291552915629157291582915929160291612916229163291642916529166291672916829169291702917129172291732917429175291762917729178291792918029181291822918329184291852918629187291882918929190291912919229193291942919529196291972919829199292002920129202292032920429205292062920729208292092921029211292122921329214292152921629217292182921929220292212922229223292242922529226292272922829229292302923129232292332923429235292362923729238292392924029241292422924329244292452924629247292482924929250292512925229253292542925529256292572925829259292602926129262292632926429265292662926729268292692927029271292722927329274292752927629277292782927929280292812928229283292842928529286292872928829289292902929129292292932929429295292962929729298292992930029301293022930329304293052930629307293082930929310293112931229313293142931529316293172931829319293202932129322293232932429325293262932729328293292933029331293322933329334293352933629337293382933929340293412934229343293442934529346293472934829349293502935129352293532935429355293562935729358293592936029361293622936329364293652936629367293682936929370293712937229373293742937529376293772937829379293802938129382293832938429385293862938729388293892939029391293922939329394293952939629397293982939929400294012940229403294042940529406294072940829409294102941129412294132941429415294162941729418294192942029421294222942329424294252942629427294282942929430294312943229433294342943529436294372943829439294402944129442294432944429445294462944729448294492945029451294522945329454294552945629457294582945929460294612946229463294642946529466294672946829469294702947129472294732947429475294762947729478294792948029481294822948329484294852948629487294882948929490294912949229493294942949529496294972949829499295002950129502295032950429505295062950729508295092951029511295122951329514295152951629517295182951929520295212952229523295242952529526295272952829529295302953129532295332953429535295362953729538295392954029541295422954329544295452954629547295482954929550295512955229553295542955529556295572955829559295602956129562295632956429565295662956729568295692957029571295722957329574295752957629577295782957929580295812958229583295842958529586295872958829589295902959129592295932959429595295962959729598295992960029601296022960329604296052960629607296082960929610296112961229613296142961529616296172961829619296202962129622296232962429625296262962729628296292963029631296322963329634296352963629637296382963929640296412964229643296442964529646296472964829649296502965129652296532965429655296562965729658296592966029661296622966329664296652966629667296682966929670296712967229673296742967529676296772967829679296802968129682296832968429685296862968729688296892969029691296922969329694296952969629697296982969929700297012970229703297042970529706297072970829709297102971129712297132971429715297162971729718297192972029721297222972329724297252972629727297282972929730297312973229733297342973529736297372973829739297402974129742297432974429745297462974729748297492975029751297522975329754297552975629757297582975929760297612976229763297642976529766297672976829769297702977129772297732977429775297762977729778297792978029781297822978329784297852978629787297882978929790297912979229793297942979529796297972979829799298002980129802298032980429805298062980729808298092981029811298122981329814298152981629817298182981929820298212982229823298242982529826298272982829829298302983129832298332983429835298362983729838298392984029841298422984329844298452984629847298482984929850298512985229853298542985529856298572985829859298602986129862298632986429865298662986729868298692987029871298722987329874298752987629877298782987929880298812988229883298842988529886298872988829889298902989129892298932989429895298962989729898298992990029901299022990329904299052990629907299082990929910299112991229913299142991529916299172991829919299202992129922299232992429925299262992729928299292993029931299322993329934299352993629937299382993929940299412994229943299442994529946299472994829949299502995129952299532995429955299562995729958299592996029961299622996329964299652996629967299682996929970299712997229973299742997529976299772997829979299802998129982299832998429985299862998729988299892999029991299922999329994299952999629997299982999930000300013000230003300043000530006300073000830009300103001130012300133001430015300163001730018300193002030021300223002330024300253002630027300283002930030300313003230033300343003530036300373003830039300403004130042300433004430045300463004730048300493005030051300523005330054300553005630057300583005930060300613006230063300643006530066300673006830069300703007130072300733007430075300763007730078300793008030081300823008330084300853008630087300883008930090300913009230093300943009530096300973009830099301003010130102301033010430105301063010730108301093011030111301123011330114301153011630117301183011930120301213012230123301243012530126301273012830129301303013130132301333013430135301363013730138301393014030141301423014330144301453014630147301483014930150301513015230153301543015530156301573015830159301603016130162301633016430165301663016730168301693017030171301723017330174301753017630177301783017930180301813018230183301843018530186301873018830189301903019130192301933019430195301963019730198301993020030201302023020330204302053020630207302083020930210302113021230213302143021530216302173021830219302203022130222302233022430225302263022730228302293023030231302323023330234302353023630237302383023930240302413024230243302443024530246302473024830249302503025130252302533025430255302563025730258302593026030261302623026330264302653026630267302683026930270302713027230273302743027530276302773027830279302803028130282302833028430285302863028730288302893029030291302923029330294302953029630297302983029930300303013030230303303043030530306303073030830309303103031130312303133031430315303163031730318303193032030321303223032330324303253032630327303283032930330303313033230333303343033530336303373033830339303403034130342303433034430345303463034730348303493035030351303523035330354303553035630357303583035930360303613036230363303643036530366303673036830369303703037130372303733037430375303763037730378303793038030381303823038330384303853038630387303883038930390303913039230393303943039530396303973039830399304003040130402304033040430405304063040730408304093041030411304123041330414304153041630417304183041930420304213042230423304243042530426304273042830429304303043130432304333043430435304363043730438304393044030441304423044330444304453044630447304483044930450304513045230453304543045530456304573045830459304603046130462304633046430465304663046730468304693047030471304723047330474304753047630477304783047930480304813048230483304843048530486304873048830489304903049130492304933049430495304963049730498304993050030501305023050330504305053050630507305083050930510305113051230513305143051530516305173051830519305203052130522305233052430525305263052730528305293053030531305323053330534305353053630537305383053930540305413054230543305443054530546305473054830549305503055130552305533055430555305563055730558305593056030561305623056330564305653056630567305683056930570305713057230573305743057530576305773057830579305803058130582305833058430585305863058730588305893059030591305923059330594305953059630597305983059930600306013060230603306043060530606306073060830609306103061130612306133061430615306163061730618306193062030621306223062330624306253062630627306283062930630306313063230633306343063530636306373063830639306403064130642306433064430645306463064730648306493065030651306523065330654306553065630657306583065930660306613066230663306643066530666306673066830669306703067130672306733067430675306763067730678306793068030681306823068330684306853068630687306883068930690306913069230693306943069530696306973069830699307003070130702307033070430705307063070730708307093071030711307123071330714307153071630717307183071930720307213072230723307243072530726307273072830729307303073130732307333073430735307363073730738307393074030741307423074330744307453074630747307483074930750307513075230753307543075530756307573075830759307603076130762307633076430765307663076730768307693077030771307723077330774307753077630777307783077930780307813078230783307843078530786307873078830789307903079130792307933079430795307963079730798307993080030801308023080330804308053080630807308083080930810308113081230813308143081530816308173081830819308203082130822308233082430825308263082730828308293083030831308323083330834308353083630837308383083930840308413084230843308443084530846308473084830849308503085130852308533085430855308563085730858308593086030861308623086330864308653086630867308683086930870308713087230873308743087530876308773087830879308803088130882308833088430885308863088730888308893089030891308923089330894308953089630897308983089930900309013090230903309043090530906309073090830909309103091130912309133091430915309163091730918309193092030921309223092330924309253092630927309283092930930309313093230933309343093530936309373093830939309403094130942309433094430945309463094730948309493095030951309523095330954309553095630957309583095930960309613096230963309643096530966309673096830969309703097130972309733097430975309763097730978309793098030981309823098330984309853098630987309883098930990309913099230993309943099530996309973099830999310003100131002310033100431005310063100731008310093101031011310123101331014310153101631017310183101931020310213102231023310243102531026310273102831029310303103131032310333103431035310363103731038310393104031041310423104331044310453104631047310483104931050310513105231053310543105531056310573105831059310603106131062310633106431065310663106731068310693107031071310723107331074310753107631077310783107931080310813108231083310843108531086310873108831089310903109131092310933109431095310963109731098310993110031101311023110331104311053110631107311083110931110311113111231113311143111531116311173111831119311203112131122311233112431125311263112731128311293113031131311323113331134311353113631137311383113931140311413114231143311443114531146311473114831149311503115131152311533115431155311563115731158311593116031161311623116331164311653116631167311683116931170311713117231173311743117531176311773117831179311803118131182311833118431185311863118731188311893119031191311923119331194311953119631197311983119931200312013120231203312043120531206312073120831209312103121131212312133121431215312163121731218312193122031221312223122331224312253122631227312283122931230312313123231233312343123531236312373123831239312403124131242312433124431245312463124731248312493125031251312523125331254312553125631257312583125931260312613126231263312643126531266312673126831269312703127131272312733127431275312763127731278312793128031281312823128331284312853128631287312883128931290312913129231293312943129531296312973129831299313003130131302313033130431305313063130731308313093131031311313123131331314313153131631317313183131931320313213132231323313243132531326313273132831329313303133131332313333133431335313363133731338313393134031341313423134331344313453134631347313483134931350313513135231353313543135531356313573135831359313603136131362313633136431365313663136731368313693137031371313723137331374313753137631377313783137931380313813138231383313843138531386313873138831389313903139131392313933139431395313963139731398313993140031401314023140331404314053140631407314083140931410314113141231413314143141531416314173141831419314203142131422314233142431425314263142731428314293143031431314323143331434314353143631437314383143931440314413144231443314443144531446314473144831449314503145131452314533145431455314563145731458314593146031461314623146331464314653146631467314683146931470314713147231473314743147531476314773147831479314803148131482314833148431485314863148731488314893149031491314923149331494314953149631497314983149931500315013150231503315043150531506315073150831509315103151131512315133151431515315163151731518315193152031521315223152331524315253152631527315283152931530315313153231533315343153531536315373153831539315403154131542315433154431545315463154731548315493155031551315523155331554315553155631557315583155931560315613156231563315643156531566315673156831569315703157131572315733157431575315763157731578315793158031581315823158331584315853158631587315883158931590315913159231593315943159531596315973159831599316003160131602316033160431605316063160731608316093161031611316123161331614316153161631617316183161931620316213162231623316243162531626316273162831629316303163131632316333163431635316363163731638316393164031641316423164331644316453164631647316483164931650316513165231653316543165531656316573165831659316603166131662316633166431665316663166731668316693167031671316723167331674316753167631677316783167931680316813168231683316843168531686316873168831689316903169131692316933169431695316963169731698316993170031701317023170331704317053170631707317083170931710317113171231713317143171531716317173171831719317203172131722317233172431725317263172731728317293173031731317323173331734317353173631737317383173931740317413174231743317443174531746317473174831749317503175131752317533175431755317563175731758317593176031761317623176331764317653176631767317683176931770317713177231773317743177531776317773177831779317803178131782317833178431785317863178731788317893179031791317923179331794317953179631797317983179931800318013180231803318043180531806318073180831809318103181131812318133181431815318163181731818318193182031821318223182331824318253182631827318283182931830318313183231833318343183531836318373183831839318403184131842318433184431845318463184731848318493185031851318523185331854318553185631857318583185931860318613186231863318643186531866318673186831869318703187131872318733187431875318763187731878318793188031881318823188331884318853188631887318883188931890318913189231893318943189531896318973189831899319003190131902319033190431905319063190731908319093191031911319123191331914319153191631917319183191931920319213192231923319243192531926319273192831929319303193131932319333193431935319363193731938319393194031941319423194331944319453194631947319483194931950319513195231953319543195531956319573195831959319603196131962319633196431965319663196731968319693197031971319723197331974319753197631977319783197931980319813198231983319843198531986319873198831989319903199131992319933199431995319963199731998319993200032001320023200332004320053200632007320083200932010320113201232013320143201532016320173201832019320203202132022320233202432025320263202732028320293203032031320323203332034320353203632037320383203932040320413204232043320443204532046320473204832049320503205132052320533205432055320563205732058320593206032061320623206332064320653206632067320683206932070320713207232073320743207532076320773207832079320803208132082320833208432085320863208732088320893209032091320923209332094320953209632097320983209932100321013210232103321043210532106321073210832109321103211132112321133211432115321163211732118321193212032121321223212332124321253212632127321283212932130321313213232133321343213532136321373213832139321403214132142321433214432145321463214732148321493215032151321523215332154321553215632157321583215932160321613216232163321643216532166321673216832169321703217132172321733217432175321763217732178321793218032181321823218332184321853218632187321883218932190321913219232193321943219532196321973219832199322003220132202322033220432205322063220732208322093221032211322123221332214322153221632217322183221932220322213222232223322243222532226322273222832229322303223132232322333223432235322363223732238322393224032241322423224332244322453224632247322483224932250322513225232253322543225532256322573225832259322603226132262322633226432265322663226732268322693227032271322723227332274322753227632277322783227932280322813228232283322843228532286322873228832289322903229132292322933229432295322963229732298322993230032301323023230332304323053230632307323083230932310323113231232313323143231532316323173231832319323203232132322323233232432325323263232732328323293233032331323323233332334323353233632337323383233932340323413234232343323443234532346323473234832349323503235132352323533235432355323563235732358323593236032361323623236332364323653236632367323683236932370323713237232373323743237532376323773237832379323803238132382323833238432385323863238732388323893239032391323923239332394323953239632397323983239932400324013240232403324043240532406324073240832409324103241132412324133241432415324163241732418324193242032421324223242332424324253242632427324283242932430324313243232433324343243532436324373243832439324403244132442324433244432445324463244732448324493245032451324523245332454324553245632457324583245932460324613246232463324643246532466324673246832469324703247132472324733247432475324763247732478324793248032481324823248332484324853248632487324883248932490324913249232493324943249532496324973249832499325003250132502325033250432505325063250732508325093251032511325123251332514325153251632517325183251932520325213252232523325243252532526325273252832529325303253132532325333253432535325363253732538325393254032541325423254332544325453254632547325483254932550325513255232553325543255532556325573255832559325603256132562325633256432565325663256732568325693257032571325723257332574325753257632577325783257932580325813258232583325843258532586325873258832589325903259132592325933259432595325963259732598325993260032601326023260332604326053260632607326083260932610326113261232613326143261532616326173261832619326203262132622326233262432625326263262732628326293263032631326323263332634326353263632637326383263932640326413264232643326443264532646326473264832649326503265132652326533265432655326563265732658326593266032661326623266332664326653266632667326683266932670326713267232673326743267532676326773267832679326803268132682326833268432685326863268732688326893269032691326923269332694326953269632697326983269932700327013270232703327043270532706327073270832709327103271132712327133271432715327163271732718327193272032721327223272332724327253272632727327283272932730327313273232733327343273532736327373273832739327403274132742327433274432745327463274732748327493275032751327523275332754327553275632757327583275932760327613276232763327643276532766327673276832769327703277132772327733277432775327763277732778327793278032781327823278332784327853278632787327883278932790327913279232793327943279532796327973279832799328003280132802328033280432805328063280732808328093281032811328123281332814328153281632817328183281932820328213282232823328243282532826328273282832829328303283132832328333283432835328363283732838328393284032841328423284332844328453284632847328483284932850328513285232853328543285532856328573285832859328603286132862328633286432865328663286732868328693287032871328723287332874328753287632877328783287932880328813288232883328843288532886328873288832889328903289132892328933289432895328963289732898328993290032901329023290332904329053290632907329083290932910329113291232913329143291532916329173291832919329203292132922329233292432925329263292732928329293293032931329323293332934329353293632937329383293932940329413294232943329443294532946329473294832949329503295132952329533295432955329563295732958329593296032961329623296332964329653296632967329683296932970329713297232973329743297532976329773297832979329803298132982329833298432985329863298732988329893299032991329923299332994329953299632997329983299933000330013300233003330043300533006330073300833009330103301133012330133301433015330163301733018330193302033021330223302333024330253302633027330283302933030330313303233033330343303533036330373303833039330403304133042330433304433045330463304733048330493305033051330523305333054330553305633057330583305933060330613306233063330643306533066330673306833069330703307133072330733307433075330763307733078330793308033081330823308333084330853308633087330883308933090330913309233093330943309533096330973309833099331003310133102331033310433105331063310733108331093311033111331123311333114331153311633117331183311933120331213312233123331243312533126331273312833129331303313133132331333313433135331363313733138331393314033141331423314333144331453314633147331483314933150331513315233153331543315533156331573315833159331603316133162331633316433165331663316733168331693317033171331723317333174331753317633177331783317933180331813318233183331843318533186331873318833189331903319133192331933319433195331963319733198331993320033201332023320333204332053320633207332083320933210332113321233213332143321533216332173321833219332203322133222332233322433225332263322733228332293323033231332323323333234332353323633237332383323933240332413324233243332443324533246332473324833249332503325133252332533325433255332563325733258332593326033261332623326333264332653326633267332683326933270332713327233273332743327533276332773327833279332803328133282332833328433285332863328733288332893329033291332923329333294332953329633297332983329933300333013330233303333043330533306333073330833309333103331133312333133331433315333163331733318333193332033321333223332333324333253332633327333283332933330333313333233333333343333533336333373333833339333403334133342333433334433345333463334733348333493335033351333523335333354333553335633357333583335933360333613336233363333643336533366333673336833369333703337133372333733337433375333763337733378333793338033381333823338333384333853338633387333883338933390333913339233393333943339533396333973339833399334003340133402334033340433405334063340733408334093341033411334123341333414334153341633417334183341933420334213342233423334243342533426334273342833429334303343133432334333343433435334363343733438334393344033441334423344333444334453344633447334483344933450334513345233453334543345533456334573345833459334603346133462334633346433465334663346733468334693347033471334723347333474334753347633477334783347933480334813348233483334843348533486334873348833489334903349133492334933349433495334963349733498334993350033501335023350333504335053350633507335083350933510335113351233513335143351533516335173351833519335203352133522335233352433525335263352733528335293353033531335323353333534335353353633537335383353933540335413354233543335443354533546335473354833549335503355133552335533355433555335563355733558335593356033561335623356333564335653356633567335683356933570335713357233573335743357533576335773357833579335803358133582335833358433585335863358733588335893359033591335923359333594335953359633597335983359933600336013360233603336043360533606336073360833609336103361133612336133361433615336163361733618336193362033621336223362333624336253362633627336283362933630336313363233633336343363533636336373363833639336403364133642336433364433645336463364733648336493365033651336523365333654336553365633657336583365933660336613366233663336643366533666336673366833669336703367133672336733367433675336763367733678336793368033681336823368333684336853368633687336883368933690336913369233693336943369533696336973369833699337003370133702337033370433705337063370733708337093371033711337123371333714337153371633717337183371933720337213372233723337243372533726337273372833729337303373133732337333373433735337363373733738337393374033741337423374333744337453374633747337483374933750337513375233753337543375533756337573375833759337603376133762337633376433765337663376733768337693377033771337723377333774337753377633777337783377933780337813378233783337843378533786337873378833789337903379133792337933379433795337963379733798337993380033801338023380333804338053380633807338083380933810338113381233813338143381533816338173381833819338203382133822338233382433825338263382733828338293383033831338323383333834338353383633837338383383933840338413384233843338443384533846338473384833849338503385133852338533385433855338563385733858338593386033861338623386333864338653386633867338683386933870338713387233873338743387533876338773387833879338803388133882338833388433885338863388733888338893389033891338923389333894338953389633897338983389933900339013390233903339043390533906339073390833909339103391133912339133391433915339163391733918339193392033921339223392333924339253392633927339283392933930339313393233933339343393533936339373393833939339403394133942339433394433945339463394733948339493395033951339523395333954339553395633957339583395933960339613396233963339643396533966339673396833969339703397133972339733397433975339763397733978339793398033981339823398333984339853398633987339883398933990339913399233993339943399533996339973399833999340003400134002340033400434005340063400734008340093401034011340123401334014340153401634017340183401934020340213402234023340243402534026340273402834029340303403134032340333403434035340363403734038340393404034041340423404334044340453404634047340483404934050340513405234053340543405534056340573405834059340603406134062340633406434065340663406734068340693407034071340723407334074340753407634077340783407934080340813408234083340843408534086340873408834089340903409134092340933409434095340963409734098340993410034101341023410334104341053410634107341083410934110341113411234113341143411534116341173411834119341203412134122341233412434125341263412734128341293413034131341323413334134341353413634137341383413934140341413414234143341443414534146341473414834149341503415134152341533415434155341563415734158341593416034161341623416334164341653416634167341683416934170341713417234173341743417534176341773417834179341803418134182341833418434185341863418734188341893419034191341923419334194341953419634197341983419934200342013420234203342043420534206342073420834209342103421134212342133421434215342163421734218342193422034221342223422334224342253422634227342283422934230342313423234233342343423534236342373423834239342403424134242342433424434245342463424734248342493425034251342523425334254342553425634257342583425934260342613426234263342643426534266342673426834269342703427134272342733427434275342763427734278342793428034281342823428334284342853428634287342883428934290342913429234293342943429534296342973429834299343003430134302343033430434305343063430734308343093431034311343123431334314343153431634317343183431934320343213432234323343243432534326343273432834329343303433134332343333433434335343363433734338343393434034341343423434334344343453434634347343483434934350343513435234353343543435534356343573435834359343603436134362343633436434365343663436734368343693437034371343723437334374343753437634377343783437934380343813438234383343843438534386343873438834389343903439134392343933439434395343963439734398343993440034401344023440334404344053440634407344083440934410344113441234413344143441534416344173441834419344203442134422344233442434425344263442734428344293443034431344323443334434344353443634437344383443934440344413444234443344443444534446344473444834449344503445134452344533445434455344563445734458344593446034461344623446334464344653446634467344683446934470344713447234473344743447534476344773447834479344803448134482344833448434485344863448734488344893449034491344923449334494344953449634497344983449934500345013450234503345043450534506345073450834509345103451134512345133451434515345163451734518345193452034521345223452334524345253452634527345283452934530345313453234533345343453534536345373453834539345403454134542345433454434545345463454734548345493455034551345523455334554345553455634557345583455934560345613456234563345643456534566345673456834569345703457134572345733457434575345763457734578345793458034581345823458334584345853458634587345883458934590345913459234593345943459534596345973459834599346003460134602346033460434605346063460734608346093461034611346123461334614346153461634617346183461934620346213462234623346243462534626346273462834629346303463134632346333463434635346363463734638346393464034641346423464334644346453464634647346483464934650346513465234653346543465534656346573465834659346603466134662346633466434665346663466734668346693467034671346723467334674346753467634677346783467934680346813468234683346843468534686346873468834689346903469134692346933469434695346963469734698346993470034701347023470334704347053470634707347083470934710347113471234713347143471534716347173471834719347203472134722347233472434725347263472734728347293473034731347323473334734347353473634737347383473934740347413474234743347443474534746347473474834749347503475134752347533475434755347563475734758347593476034761347623476334764347653476634767347683476934770347713477234773347743477534776347773477834779347803478134782347833478434785347863478734788347893479034791347923479334794347953479634797347983479934800348013480234803348043480534806348073480834809348103481134812348133481434815348163481734818348193482034821348223482334824348253482634827348283482934830348313483234833348343483534836348373483834839348403484134842348433484434845348463484734848348493485034851348523485334854348553485634857348583485934860348613486234863348643486534866348673486834869348703487134872348733487434875348763487734878348793488034881348823488334884348853488634887348883488934890348913489234893348943489534896348973489834899349003490134902349033490434905349063490734908349093491034911349123491334914349153491634917349183491934920349213492234923349243492534926349273492834929349303493134932349333493434935349363493734938349393494034941349423494334944349453494634947349483494934950349513495234953349543495534956349573495834959349603496134962349633496434965349663496734968349693497034971349723497334974349753497634977349783497934980349813498234983349843498534986349873498834989349903499134992349933499434995349963499734998349993500035001350023500335004350053500635007350083500935010350113501235013350143501535016350173501835019350203502135022350233502435025350263502735028350293503035031350323503335034350353503635037350383503935040350413504235043350443504535046350473504835049350503505135052350533505435055350563505735058350593506035061350623506335064350653506635067350683506935070350713507235073350743507535076350773507835079350803508135082350833508435085350863508735088350893509035091350923509335094350953509635097350983509935100351013510235103351043510535106351073510835109351103511135112351133511435115351163511735118351193512035121351223512335124351253512635127351283512935130351313513235133351343513535136351373513835139351403514135142351433514435145351463514735148351493515035151351523515335154351553515635157351583515935160351613516235163351643516535166351673516835169351703517135172351733517435175351763517735178351793518035181351823518335184351853518635187351883518935190351913519235193351943519535196351973519835199352003520135202352033520435205352063520735208352093521035211352123521335214352153521635217352183521935220352213522235223352243522535226352273522835229352303523135232352333523435235352363523735238352393524035241352423524335244352453524635247352483524935250352513525235253352543525535256352573525835259352603526135262352633526435265352663526735268352693527035271352723527335274352753527635277352783527935280352813528235283352843528535286352873528835289352903529135292352933529435295352963529735298352993530035301353023530335304353053530635307353083530935310353113531235313353143531535316353173531835319353203532135322353233532435325353263532735328353293533035331353323533335334353353533635337353383533935340353413534235343353443534535346353473534835349353503535135352353533535435355353563535735358353593536035361353623536335364353653536635367353683536935370353713537235373353743537535376353773537835379353803538135382353833538435385353863538735388353893539035391353923539335394353953539635397353983539935400354013540235403354043540535406354073540835409354103541135412354133541435415354163541735418354193542035421354223542335424354253542635427354283542935430354313543235433354343543535436354373543835439354403544135442354433544435445354463544735448354493545035451354523545335454354553545635457354583545935460354613546235463354643546535466354673546835469354703547135472354733547435475354763547735478354793548035481354823548335484354853548635487354883548935490354913549235493354943549535496354973549835499355003550135502355033550435505355063550735508355093551035511355123551335514355153551635517355183551935520355213552235523355243552535526355273552835529355303553135532355333553435535355363553735538355393554035541355423554335544355453554635547355483554935550355513555235553355543555535556355573555835559355603556135562355633556435565355663556735568355693557035571355723557335574355753557635577355783557935580355813558235583355843558535586355873558835589355903559135592355933559435595355963559735598355993560035601356023560335604356053560635607356083560935610356113561235613356143561535616356173561835619356203562135622356233562435625356263562735628356293563035631356323563335634356353563635637356383563935640356413564235643356443564535646356473564835649356503565135652356533565435655356563565735658356593566035661356623566335664356653566635667356683566935670356713567235673356743567535676356773567835679356803568135682356833568435685356863568735688356893569035691356923569335694356953569635697356983569935700357013570235703357043570535706357073570835709357103571135712357133571435715357163571735718357193572035721357223572335724357253572635727357283572935730357313573235733357343573535736357373573835739357403574135742357433574435745357463574735748357493575035751357523575335754357553575635757357583575935760357613576235763357643576535766357673576835769357703577135772357733577435775357763577735778357793578035781357823578335784357853578635787357883578935790357913579235793357943579535796357973579835799358003580135802358033580435805358063580735808358093581035811358123581335814358153581635817358183581935820358213582235823358243582535826358273582835829358303583135832358333583435835358363583735838358393584035841358423584335844358453584635847358483584935850358513585235853358543585535856358573585835859358603586135862358633586435865358663586735868358693587035871358723587335874358753587635877358783587935880358813588235883358843588535886358873588835889358903589135892358933589435895358963589735898358993590035901359023590335904359053590635907359083590935910359113591235913359143591535916359173591835919359203592135922359233592435925359263592735928359293593035931359323593335934359353593635937359383593935940359413594235943359443594535946359473594835949359503595135952359533595435955359563595735958359593596035961359623596335964359653596635967359683596935970359713597235973359743597535976359773597835979359803598135982359833598435985359863598735988359893599035991359923599335994359953599635997359983599936000360013600236003360043600536006360073600836009360103601136012360133601436015360163601736018360193602036021360223602336024360253602636027360283602936030360313603236033360343603536036360373603836039360403604136042360433604436045360463604736048360493605036051360523605336054360553605636057360583605936060360613606236063360643606536066360673606836069360703607136072360733607436075360763607736078360793608036081360823608336084360853608636087360883608936090360913609236093360943609536096360973609836099361003610136102361033610436105361063610736108361093611036111361123611336114361153611636117361183611936120361213612236123361243612536126361273612836129361303613136132361333613436135361363613736138361393614036141361423614336144361453614636147361483614936150361513615236153361543615536156361573615836159361603616136162361633616436165361663616736168361693617036171361723617336174361753617636177361783617936180361813618236183361843618536186361873618836189361903619136192361933619436195361963619736198361993620036201362023620336204362053620636207362083620936210362113621236213362143621536216362173621836219362203622136222362233622436225362263622736228362293623036231362323623336234362353623636237362383623936240362413624236243362443624536246362473624836249362503625136252362533625436255362563625736258362593626036261362623626336264362653626636267362683626936270362713627236273362743627536276362773627836279362803628136282362833628436285362863628736288362893629036291362923629336294362953629636297362983629936300363013630236303363043630536306363073630836309363103631136312363133631436315363163631736318363193632036321363223632336324363253632636327363283632936330363313633236333363343633536336363373633836339363403634136342363433634436345363463634736348363493635036351363523635336354363553635636357363583635936360363613636236363363643636536366363673636836369363703637136372363733637436375363763637736378363793638036381363823638336384363853638636387363883638936390363913639236393363943639536396363973639836399364003640136402364033640436405364063640736408364093641036411364123641336414364153641636417364183641936420364213642236423364243642536426364273642836429364303643136432364333643436435364363643736438364393644036441364423644336444364453644636447364483644936450364513645236453364543645536456364573645836459364603646136462364633646436465364663646736468364693647036471364723647336474364753647636477364783647936480364813648236483364843648536486364873648836489364903649136492364933649436495364963649736498364993650036501365023650336504365053650636507365083650936510365113651236513365143651536516365173651836519365203652136522365233652436525365263652736528365293653036531365323653336534365353653636537365383653936540365413654236543365443654536546365473654836549365503655136552365533655436555365563655736558365593656036561365623656336564365653656636567365683656936570365713657236573365743657536576365773657836579365803658136582365833658436585365863658736588365893659036591365923659336594365953659636597365983659936600366013660236603366043660536606366073660836609366103661136612366133661436615366163661736618366193662036621366223662336624366253662636627366283662936630366313663236633366343663536636366373663836639366403664136642366433664436645366463664736648366493665036651366523665336654366553665636657366583665936660366613666236663366643666536666366673666836669366703667136672366733667436675366763667736678366793668036681366823668336684366853668636687366883668936690366913669236693366943669536696366973669836699367003670136702367033670436705367063670736708367093671036711367123671336714367153671636717367183671936720367213672236723367243672536726367273672836729367303673136732367333673436735367363673736738367393674036741367423674336744367453674636747367483674936750367513675236753367543675536756367573675836759367603676136762367633676436765367663676736768367693677036771367723677336774367753677636777367783677936780367813678236783367843678536786367873678836789367903679136792367933679436795367963679736798367993680036801368023680336804368053680636807368083680936810368113681236813368143681536816368173681836819368203682136822368233682436825368263682736828368293683036831368323683336834368353683636837368383683936840368413684236843368443684536846368473684836849368503685136852368533685436855368563685736858368593686036861368623686336864368653686636867368683686936870368713687236873368743687536876368773687836879368803688136882368833688436885368863688736888368893689036891368923689336894368953689636897368983689936900369013690236903369043690536906369073690836909369103691136912369133691436915369163691736918369193692036921369223692336924369253692636927369283692936930369313693236933369343693536936369373693836939369403694136942369433694436945369463694736948369493695036951369523695336954369553695636957369583695936960369613696236963369643696536966369673696836969369703697136972369733697436975369763697736978369793698036981369823698336984369853698636987369883698936990369913699236993369943699536996369973699836999370003700137002370033700437005370063700737008370093701037011370123701337014370153701637017370183701937020370213702237023370243702537026370273702837029370303703137032370333703437035370363703737038370393704037041370423704337044370453704637047370483704937050370513705237053370543705537056370573705837059370603706137062370633706437065370663706737068370693707037071370723707337074370753707637077370783707937080370813708237083370843708537086370873708837089370903709137092370933709437095370963709737098370993710037101371023710337104371053710637107371083710937110371113711237113371143711537116371173711837119371203712137122371233712437125371263712737128371293713037131371323713337134371353713637137371383713937140371413714237143371443714537146371473714837149371503715137152371533715437155371563715737158371593716037161371623716337164371653716637167371683716937170371713717237173371743717537176371773717837179371803718137182371833718437185371863718737188371893719037191371923719337194371953719637197371983719937200372013720237203372043720537206372073720837209372103721137212372133721437215372163721737218372193722037221372223722337224372253722637227372283722937230372313723237233372343723537236372373723837239372403724137242372433724437245372463724737248372493725037251372523725337254372553725637257372583725937260372613726237263372643726537266372673726837269372703727137272372733727437275372763727737278372793728037281372823728337284372853728637287372883728937290372913729237293372943729537296372973729837299373003730137302373033730437305373063730737308373093731037311373123731337314373153731637317373183731937320373213732237323373243732537326373273732837329373303733137332373333733437335373363733737338373393734037341373423734337344373453734637347373483734937350373513735237353373543735537356373573735837359373603736137362373633736437365373663736737368373693737037371373723737337374373753737637377373783737937380373813738237383373843738537386373873738837389373903739137392373933739437395373963739737398373993740037401374023740337404374053740637407374083740937410374113741237413374143741537416374173741837419374203742137422374233742437425374263742737428374293743037431374323743337434374353743637437374383743937440374413744237443374443744537446374473744837449374503745137452374533745437455374563745737458374593746037461374623746337464374653746637467374683746937470374713747237473374743747537476374773747837479374803748137482374833748437485374863748737488374893749037491374923749337494374953749637497374983749937500375013750237503375043750537506375073750837509375103751137512375133751437515375163751737518375193752037521375223752337524375253752637527375283752937530375313753237533375343753537536375373753837539375403754137542375433754437545375463754737548375493755037551375523755337554375553755637557375583755937560375613756237563375643756537566375673756837569375703757137572375733757437575375763757737578375793758037581375823758337584375853758637587375883758937590375913759237593375943759537596375973759837599376003760137602376033760437605376063760737608376093761037611376123761337614376153761637617376183761937620376213762237623376243762537626376273762837629376303763137632376333763437635376363763737638376393764037641376423764337644376453764637647376483764937650376513765237653376543765537656376573765837659376603766137662376633766437665376663766737668376693767037671376723767337674376753767637677376783767937680376813768237683376843768537686376873768837689376903769137692376933769437695376963769737698376993770037701377023770337704377053770637707377083770937710377113771237713377143771537716377173771837719377203772137722377233772437725377263772737728377293773037731377323773337734377353773637737377383773937740377413774237743377443774537746377473774837749377503775137752377533775437755377563775737758377593776037761377623776337764377653776637767377683776937770377713777237773377743777537776377773777837779377803778137782377833778437785377863778737788377893779037791377923779337794377953779637797377983779937800378013780237803378043780537806378073780837809378103781137812378133781437815378163781737818378193782037821378223782337824378253782637827378283782937830378313783237833378343783537836378373783837839378403784137842378433784437845378463784737848378493785037851378523785337854378553785637857378583785937860378613786237863378643786537866378673786837869378703787137872378733787437875378763787737878378793788037881378823788337884378853788637887378883788937890378913789237893378943789537896378973789837899379003790137902379033790437905379063790737908379093791037911379123791337914379153791637917379183791937920379213792237923379243792537926379273792837929379303793137932379333793437935379363793737938379393794037941379423794337944379453794637947379483794937950379513795237953379543795537956379573795837959379603796137962379633796437965379663796737968379693797037971379723797337974379753797637977379783797937980379813798237983379843798537986379873798837989379903799137992379933799437995379963799737998379993800038001380023800338004380053800638007380083800938010380113801238013380143801538016380173801838019380203802138022380233802438025380263802738028380293803038031380323803338034380353803638037380383803938040380413804238043380443804538046380473804838049380503805138052380533805438055380563805738058380593806038061380623806338064380653806638067380683806938070380713807238073380743807538076380773807838079380803808138082380833808438085380863808738088380893809038091380923809338094380953809638097380983809938100381013810238103381043810538106381073810838109381103811138112381133811438115381163811738118381193812038121381223812338124381253812638127381283812938130381313813238133381343813538136381373813838139381403814138142381433814438145381463814738148381493815038151381523815338154381553815638157381583815938160381613816238163381643816538166381673816838169381703817138172381733817438175381763817738178381793818038181381823818338184381853818638187381883818938190381913819238193381943819538196381973819838199382003820138202382033820438205382063820738208382093821038211382123821338214382153821638217382183821938220382213822238223382243822538226382273822838229382303823138232382333823438235382363823738238382393824038241382423824338244382453824638247382483824938250382513825238253382543825538256382573825838259382603826138262382633826438265382663826738268382693827038271382723827338274382753827638277382783827938280382813828238283382843828538286382873828838289382903829138292382933829438295382963829738298382993830038301383023830338304383053830638307383083830938310383113831238313383143831538316383173831838319383203832138322383233832438325383263832738328383293833038331383323833338334383353833638337383383833938340383413834238343383443834538346383473834838349383503835138352383533835438355383563835738358383593836038361383623836338364383653836638367383683836938370383713837238373383743837538376383773837838379383803838138382383833838438385383863838738388383893839038391383923839338394383953839638397383983839938400384013840238403384043840538406384073840838409384103841138412384133841438415384163841738418384193842038421384223842338424384253842638427384283842938430384313843238433384343843538436384373843838439384403844138442384433844438445384463844738448384493845038451384523845338454384553845638457384583845938460384613846238463384643846538466384673846838469384703847138472384733847438475384763847738478384793848038481384823848338484384853848638487384883848938490384913849238493384943849538496384973849838499385003850138502385033850438505385063850738508385093851038511385123851338514385153851638517385183851938520385213852238523385243852538526385273852838529385303853138532385333853438535385363853738538385393854038541385423854338544385453854638547385483854938550385513855238553385543855538556385573855838559385603856138562385633856438565385663856738568385693857038571385723857338574385753857638577385783857938580385813858238583385843858538586385873858838589385903859138592385933859438595385963859738598385993860038601386023860338604386053860638607386083860938610386113861238613386143861538616386173861838619386203862138622386233862438625386263862738628386293863038631386323863338634386353863638637386383863938640386413864238643386443864538646386473864838649386503865138652386533865438655386563865738658386593866038661386623866338664386653866638667386683866938670386713867238673386743867538676386773867838679386803868138682386833868438685386863868738688386893869038691386923869338694386953869638697386983869938700387013870238703387043870538706387073870838709387103871138712387133871438715387163871738718387193872038721387223872338724387253872638727387283872938730387313873238733387343873538736387373873838739387403874138742387433874438745387463874738748387493875038751387523875338754387553875638757387583875938760387613876238763387643876538766387673876838769387703877138772387733877438775387763877738778387793878038781387823878338784387853878638787387883878938790387913879238793387943879538796387973879838799388003880138802388033880438805388063880738808388093881038811388123881338814388153881638817388183881938820388213882238823388243882538826388273882838829388303883138832388333883438835388363883738838388393884038841388423884338844388453884638847388483884938850388513885238853388543885538856388573885838859388603886138862388633886438865388663886738868388693887038871388723887338874388753887638877388783887938880388813888238883388843888538886388873888838889388903889138892388933889438895388963889738898388993890038901389023890338904389053890638907389083890938910389113891238913389143891538916389173891838919389203892138922389233892438925389263892738928389293893038931389323893338934389353893638937389383893938940389413894238943389443894538946389473894838949389503895138952389533895438955389563895738958389593896038961389623896338964389653896638967389683896938970389713897238973389743897538976389773897838979389803898138982389833898438985389863898738988389893899038991389923899338994389953899638997389983899939000390013900239003390043900539006390073900839009390103901139012390133901439015390163901739018390193902039021390223902339024390253902639027390283902939030390313903239033390343903539036390373903839039390403904139042390433904439045390463904739048390493905039051390523905339054390553905639057390583905939060390613906239063390643906539066390673906839069390703907139072390733907439075390763907739078390793908039081390823908339084390853908639087390883908939090390913909239093390943909539096390973909839099391003910139102391033910439105391063910739108391093911039111391123911339114391153911639117391183911939120391213912239123391243912539126391273912839129391303913139132391333913439135391363913739138391393914039141391423914339144391453914639147391483914939150391513915239153391543915539156391573915839159391603916139162391633916439165391663916739168391693917039171391723917339174391753917639177391783917939180391813918239183391843918539186391873918839189391903919139192391933919439195391963919739198391993920039201392023920339204392053920639207392083920939210392113921239213392143921539216392173921839219392203922139222392233922439225392263922739228392293923039231392323923339234392353923639237392383923939240392413924239243392443924539246392473924839249392503925139252392533925439255392563925739258392593926039261392623926339264392653926639267392683926939270392713927239273392743927539276392773927839279392803928139282392833928439285392863928739288392893929039291392923929339294392953929639297392983929939300393013930239303393043930539306393073930839309393103931139312393133931439315393163931739318393193932039321393223932339324393253932639327393283932939330393313933239333393343933539336393373933839339393403934139342393433934439345393463934739348393493935039351393523935339354393553935639357393583935939360393613936239363393643936539366393673936839369393703937139372393733937439375393763937739378393793938039381393823938339384393853938639387393883938939390393913939239393393943939539396393973939839399394003940139402394033940439405394063940739408394093941039411394123941339414394153941639417394183941939420394213942239423394243942539426394273942839429394303943139432394333943439435394363943739438394393944039441394423944339444394453944639447394483944939450394513945239453394543945539456394573945839459394603946139462394633946439465394663946739468394693947039471394723947339474394753947639477394783947939480394813948239483394843948539486394873948839489394903949139492394933949439495394963949739498394993950039501395023950339504395053950639507395083950939510395113951239513395143951539516395173951839519395203952139522395233952439525395263952739528395293953039531395323953339534395353953639537395383953939540395413954239543395443954539546395473954839549395503955139552395533955439555395563955739558395593956039561395623956339564395653956639567395683956939570395713957239573395743957539576395773957839579395803958139582395833958439585395863958739588395893959039591395923959339594395953959639597395983959939600396013960239603396043960539606396073960839609396103961139612396133961439615396163961739618396193962039621396223962339624396253962639627396283962939630396313963239633396343963539636396373963839639396403964139642396433964439645396463964739648396493965039651396523965339654396553965639657396583965939660396613966239663396643966539666396673966839669396703967139672396733967439675396763967739678396793968039681396823968339684396853968639687396883968939690396913969239693396943969539696396973969839699397003970139702397033970439705397063970739708397093971039711397123971339714397153971639717397183971939720397213972239723397243972539726397273972839729397303973139732397333973439735397363973739738397393974039741397423974339744397453974639747397483974939750397513975239753397543975539756397573975839759397603976139762397633976439765397663976739768397693977039771397723977339774397753977639777397783977939780397813978239783397843978539786397873978839789397903979139792397933979439795397963979739798397993980039801398023980339804398053980639807398083980939810398113981239813398143981539816398173981839819398203982139822398233982439825398263982739828398293983039831398323983339834398353983639837398383983939840398413984239843398443984539846398473984839849398503985139852398533985439855398563985739858398593986039861398623986339864398653986639867398683986939870398713987239873398743987539876398773987839879398803988139882398833988439885398863988739888398893989039891398923989339894398953989639897398983989939900399013990239903399043990539906399073990839909399103991139912399133991439915399163991739918399193992039921399223992339924399253992639927399283992939930399313993239933399343993539936399373993839939399403994139942399433994439945399463994739948399493995039951399523995339954399553995639957399583995939960399613996239963399643996539966399673996839969399703997139972399733997439975399763997739978399793998039981399823998339984399853998639987399883998939990399913999239993399943999539996399973999839999400004000140002400034000440005400064000740008400094001040011400124001340014400154001640017400184001940020400214002240023400244002540026400274002840029400304003140032400334003440035400364003740038400394004040041400424004340044400454004640047400484004940050400514005240053400544005540056400574005840059400604006140062400634006440065400664006740068400694007040071400724007340074400754007640077400784007940080400814008240083400844008540086400874008840089400904009140092400934009440095400964009740098400994010040101401024010340104401054010640107401084010940110401114011240113401144011540116401174011840119401204012140122401234012440125401264012740128401294013040131401324013340134401354013640137401384013940140401414014240143401444014540146401474014840149401504015140152401534015440155401564015740158401594016040161401624016340164401654016640167401684016940170401714017240173401744017540176401774017840179401804018140182401834018440185401864018740188401894019040191401924019340194401954019640197401984019940200402014020240203402044020540206402074020840209402104021140212402134021440215402164021740218402194022040221402224022340224402254022640227402284022940230402314023240233402344023540236402374023840239402404024140242402434024440245402464024740248402494025040251402524025340254402554025640257402584025940260402614026240263402644026540266402674026840269402704027140272402734027440275402764027740278402794028040281402824028340284402854028640287402884028940290402914029240293402944029540296402974029840299403004030140302403034030440305403064030740308403094031040311403124031340314403154031640317403184031940320403214032240323403244032540326403274032840329403304033140332403334033440335403364033740338403394034040341403424034340344403454034640347403484034940350403514035240353403544035540356403574035840359403604036140362403634036440365403664036740368403694037040371403724037340374403754037640377403784037940380403814038240383403844038540386403874038840389403904039140392403934039440395403964039740398403994040040401404024040340404404054040640407404084040940410404114041240413404144041540416404174041840419404204042140422404234042440425404264042740428404294043040431404324043340434404354043640437404384043940440404414044240443404444044540446404474044840449404504045140452404534045440455404564045740458404594046040461404624046340464404654046640467404684046940470404714047240473404744047540476404774047840479404804048140482404834048440485404864048740488404894049040491404924049340494404954049640497404984049940500405014050240503405044050540506405074050840509405104051140512405134051440515405164051740518405194052040521405224052340524405254052640527405284052940530405314053240533405344053540536405374053840539405404054140542405434054440545405464054740548405494055040551405524055340554405554055640557405584055940560405614056240563405644056540566405674056840569405704057140572405734057440575405764057740578405794058040581405824058340584405854058640587405884058940590405914059240593405944059540596405974059840599406004060140602406034060440605406064060740608406094061040611406124061340614406154061640617406184061940620406214062240623406244062540626406274062840629406304063140632406334063440635406364063740638406394064040641406424064340644406454064640647406484064940650406514065240653406544065540656406574065840659406604066140662406634066440665406664066740668406694067040671406724067340674406754067640677406784067940680406814068240683406844068540686406874068840689406904069140692406934069440695406964069740698406994070040701407024070340704407054070640707407084070940710407114071240713407144071540716407174071840719407204072140722407234072440725407264072740728407294073040731407324073340734407354073640737407384073940740407414074240743407444074540746407474074840749407504075140752407534075440755407564075740758407594076040761407624076340764407654076640767407684076940770407714077240773407744077540776407774077840779407804078140782407834078440785407864078740788407894079040791407924079340794407954079640797407984079940800408014080240803408044080540806408074080840809408104081140812408134081440815408164081740818408194082040821408224082340824408254082640827408284082940830408314083240833408344083540836408374083840839408404084140842408434084440845408464084740848408494085040851408524085340854408554085640857408584085940860408614086240863408644086540866408674086840869408704087140872408734087440875408764087740878408794088040881408824088340884408854088640887408884088940890408914089240893408944089540896408974089840899409004090140902409034090440905409064090740908409094091040911409124091340914409154091640917409184091940920409214092240923409244092540926409274092840929409304093140932409334093440935409364093740938409394094040941409424094340944409454094640947409484094940950409514095240953409544095540956409574095840959409604096140962409634096440965409664096740968409694097040971409724097340974409754097640977409784097940980409814098240983409844098540986409874098840989409904099140992409934099440995409964099740998409994100041001410024100341004410054100641007410084100941010410114101241013410144101541016410174101841019410204102141022410234102441025410264102741028410294103041031410324103341034410354103641037410384103941040410414104241043410444104541046410474104841049410504105141052410534105441055410564105741058410594106041061410624106341064410654106641067410684106941070410714107241073410744107541076410774107841079410804108141082410834108441085410864108741088410894109041091410924109341094410954109641097410984109941100411014110241103411044110541106411074110841109411104111141112411134111441115411164111741118411194112041121411224112341124411254112641127411284112941130411314113241133411344113541136411374113841139411404114141142411434114441145411464114741148411494115041151411524115341154411554115641157411584115941160411614116241163411644116541166411674116841169411704117141172411734117441175411764117741178411794118041181411824118341184411854118641187411884118941190411914119241193411944119541196411974119841199412004120141202412034120441205412064120741208412094121041211412124121341214412154121641217412184121941220412214122241223412244122541226412274122841229412304123141232412334123441235412364123741238412394124041241412424124341244412454124641247412484124941250412514125241253412544125541256412574125841259412604126141262412634126441265412664126741268412694127041271412724127341274412754127641277412784127941280412814128241283412844128541286412874128841289412904129141292412934129441295412964129741298412994130041301413024130341304413054130641307413084130941310413114131241313413144131541316413174131841319413204132141322413234132441325413264132741328413294133041331413324133341334413354133641337413384133941340413414134241343413444134541346413474134841349413504135141352413534135441355413564135741358413594136041361413624136341364413654136641367413684136941370413714137241373413744137541376413774137841379413804138141382413834138441385413864138741388413894139041391413924139341394413954139641397413984139941400414014140241403414044140541406414074140841409414104141141412414134141441415414164141741418414194142041421414224142341424414254142641427414284142941430414314143241433414344143541436414374143841439414404144141442414434144441445414464144741448414494145041451414524145341454414554145641457414584145941460414614146241463414644146541466414674146841469414704147141472414734147441475414764147741478414794148041481414824148341484414854148641487414884148941490414914149241493414944149541496414974149841499415004150141502415034150441505415064150741508415094151041511415124151341514415154151641517415184151941520415214152241523415244152541526415274152841529415304153141532415334153441535415364153741538415394154041541415424154341544415454154641547415484154941550415514155241553415544155541556415574155841559415604156141562415634156441565415664156741568415694157041571415724157341574415754157641577415784157941580415814158241583415844158541586415874158841589415904159141592415934159441595415964159741598415994160041601416024160341604416054160641607416084160941610416114161241613416144161541616416174161841619416204162141622416234162441625416264162741628416294163041631416324163341634416354163641637416384163941640416414164241643416444164541646416474164841649416504165141652416534165441655416564165741658416594166041661416624166341664416654166641667416684166941670416714167241673416744167541676416774167841679416804168141682416834168441685416864168741688416894169041691416924169341694416954169641697416984169941700417014170241703417044170541706417074170841709417104171141712417134171441715417164171741718417194172041721417224172341724417254172641727417284172941730417314173241733417344173541736417374173841739417404174141742417434174441745417464174741748417494175041751417524175341754417554175641757417584175941760417614176241763417644176541766417674176841769417704177141772417734177441775417764177741778417794178041781417824178341784417854178641787417884178941790417914179241793417944179541796417974179841799418004180141802418034180441805418064180741808418094181041811418124181341814418154181641817418184181941820418214182241823418244182541826418274182841829418304183141832418334183441835418364183741838418394184041841418424184341844418454184641847418484184941850418514185241853418544185541856418574185841859418604186141862418634186441865418664186741868418694187041871418724187341874418754187641877418784187941880418814188241883418844188541886418874188841889418904189141892418934189441895418964189741898418994190041901419024190341904419054190641907419084190941910419114191241913419144191541916419174191841919419204192141922419234192441925419264192741928419294193041931419324193341934419354193641937419384193941940419414194241943419444194541946419474194841949419504195141952419534195441955419564195741958419594196041961419624196341964419654196641967419684196941970419714197241973419744197541976419774197841979419804198141982419834198441985419864198741988419894199041991419924199341994419954199641997419984199942000420014200242003420044200542006420074200842009420104201142012420134201442015420164201742018420194202042021420224202342024420254202642027420284202942030420314203242033420344203542036420374203842039420404204142042420434204442045420464204742048420494205042051420524205342054420554205642057420584205942060420614206242063420644206542066420674206842069420704207142072420734207442075420764207742078420794208042081420824208342084420854208642087420884208942090420914209242093420944209542096420974209842099421004210142102421034210442105421064210742108421094211042111421124211342114421154211642117421184211942120421214212242123421244212542126421274212842129421304213142132421334213442135421364213742138421394214042141421424214342144421454214642147421484214942150421514215242153421544215542156421574215842159421604216142162421634216442165421664216742168421694217042171421724217342174421754217642177421784217942180421814218242183421844218542186421874218842189421904219142192421934219442195421964219742198421994220042201422024220342204422054220642207422084220942210422114221242213422144221542216422174221842219422204222142222422234222442225422264222742228422294223042231422324223342234422354223642237422384223942240422414224242243422444224542246422474224842249422504225142252422534225442255422564225742258422594226042261422624226342264422654226642267422684226942270422714227242273422744227542276422774227842279422804228142282422834228442285422864228742288422894229042291422924229342294422954229642297422984229942300423014230242303423044230542306423074230842309423104231142312423134231442315423164231742318423194232042321423224232342324423254232642327423284232942330423314233242333423344233542336423374233842339423404234142342423434234442345423464234742348423494235042351423524235342354423554235642357423584235942360423614236242363423644236542366423674236842369423704237142372423734237442375423764237742378423794238042381423824238342384423854238642387423884238942390423914239242393423944239542396423974239842399424004240142402424034240442405424064240742408424094241042411424124241342414424154241642417424184241942420424214242242423424244242542426424274242842429424304243142432424334243442435424364243742438424394244042441424424244342444424454244642447424484244942450424514245242453424544245542456424574245842459424604246142462424634246442465424664246742468424694247042471424724247342474424754247642477424784247942480424814248242483424844248542486424874248842489424904249142492424934249442495424964249742498424994250042501425024250342504425054250642507425084250942510425114251242513425144251542516425174251842519425204252142522425234252442525425264252742528425294253042531425324253342534425354253642537425384253942540425414254242543425444254542546425474254842549425504255142552425534255442555425564255742558425594256042561425624256342564425654256642567425684256942570425714257242573425744257542576425774257842579425804258142582425834258442585425864258742588425894259042591425924259342594425954259642597425984259942600426014260242603426044260542606426074260842609426104261142612426134261442615426164261742618426194262042621426224262342624426254262642627426284262942630426314263242633426344263542636426374263842639426404264142642426434264442645426464264742648426494265042651426524265342654426554265642657426584265942660426614266242663426644266542666426674266842669426704267142672426734267442675426764267742678426794268042681426824268342684426854268642687426884268942690426914269242693426944269542696426974269842699427004270142702427034270442705427064270742708427094271042711427124271342714427154271642717427184271942720427214272242723427244272542726427274272842729427304273142732427334273442735427364273742738427394274042741427424274342744427454274642747427484274942750427514275242753427544275542756427574275842759427604276142762427634276442765427664276742768427694277042771427724277342774427754277642777427784277942780427814278242783427844278542786427874278842789427904279142792427934279442795427964279742798427994280042801428024280342804428054280642807428084280942810428114281242813428144281542816428174281842819428204282142822428234282442825428264282742828428294283042831428324283342834428354283642837428384283942840428414284242843428444284542846428474284842849428504285142852428534285442855428564285742858428594286042861428624286342864428654286642867428684286942870428714287242873428744287542876428774287842879428804288142882428834288442885428864288742888428894289042891428924289342894428954289642897428984289942900429014290242903429044290542906429074290842909429104291142912429134291442915429164291742918429194292042921429224292342924429254292642927429284292942930429314293242933429344293542936429374293842939429404294142942429434294442945429464294742948429494295042951429524295342954429554295642957429584295942960429614296242963429644296542966429674296842969429704297142972429734297442975429764297742978429794298042981429824298342984429854298642987429884298942990429914299242993429944299542996429974299842999430004300143002430034300443005430064300743008430094301043011430124301343014430154301643017430184301943020430214302243023430244302543026430274302843029430304303143032430334303443035430364303743038430394304043041430424304343044430454304643047430484304943050430514305243053430544305543056430574305843059430604306143062430634306443065430664306743068430694307043071430724307343074430754307643077430784307943080430814308243083430844308543086430874308843089430904309143092430934309443095430964309743098430994310043101431024310343104431054310643107431084310943110431114311243113431144311543116431174311843119431204312143122431234312443125431264312743128431294313043131431324313343134431354313643137431384313943140431414314243143431444314543146431474314843149431504315143152431534315443155431564315743158431594316043161431624316343164431654316643167431684316943170431714317243173431744317543176431774317843179431804318143182431834318443185431864318743188431894319043191431924319343194431954319643197431984319943200432014320243203432044320543206432074320843209432104321143212432134321443215432164321743218432194322043221432224322343224432254322643227432284322943230432314323243233432344323543236432374323843239432404324143242432434324443245432464324743248432494325043251432524325343254432554325643257432584325943260432614326243263432644326543266432674326843269432704327143272432734327443275432764327743278432794328043281432824328343284432854328643287432884328943290432914329243293432944329543296432974329843299433004330143302433034330443305433064330743308433094331043311433124331343314433154331643317433184331943320433214332243323433244332543326433274332843329433304333143332433334333443335433364333743338433394334043341433424334343344433454334643347433484334943350433514335243353433544335543356433574335843359433604336143362433634336443365433664336743368433694337043371433724337343374433754337643377433784337943380433814338243383433844338543386433874338843389433904339143392433934339443395433964339743398433994340043401434024340343404434054340643407434084340943410434114341243413434144341543416434174341843419434204342143422434234342443425434264342743428434294343043431434324343343434434354343643437434384343943440434414344243443434444344543446434474344843449434504345143452434534345443455434564345743458434594346043461434624346343464434654346643467434684346943470434714347243473434744347543476434774347843479434804348143482434834348443485434864348743488434894349043491434924349343494434954349643497434984349943500435014350243503435044350543506435074350843509435104351143512435134351443515435164351743518435194352043521435224352343524435254352643527435284352943530435314353243533435344353543536435374353843539435404354143542435434354443545435464354743548435494355043551435524355343554435554355643557435584355943560435614356243563435644356543566435674356843569435704357143572435734357443575435764357743578435794358043581435824358343584435854358643587435884358943590435914359243593435944359543596435974359843599436004360143602436034360443605436064360743608436094361043611436124361343614436154361643617436184361943620436214362243623436244362543626436274362843629436304363143632436334363443635436364363743638436394364043641436424364343644436454364643647436484364943650436514365243653436544365543656436574365843659436604366143662436634366443665436664366743668436694367043671436724367343674436754367643677436784367943680436814368243683436844368543686436874368843689436904369143692436934369443695436964369743698436994370043701437024370343704437054370643707437084370943710437114371243713437144371543716437174371843719437204372143722437234372443725437264372743728437294373043731437324373343734437354373643737437384373943740437414374243743437444374543746437474374843749437504375143752437534375443755437564375743758437594376043761437624376343764437654376643767437684376943770437714377243773437744377543776437774377843779437804378143782437834378443785437864378743788437894379043791437924379343794437954379643797437984379943800438014380243803438044380543806438074380843809438104381143812438134381443815438164381743818438194382043821438224382343824438254382643827438284382943830438314383243833438344383543836438374383843839438404384143842438434384443845438464384743848438494385043851438524385343854438554385643857438584385943860438614386243863438644386543866438674386843869438704387143872438734387443875438764387743878438794388043881438824388343884438854388643887438884388943890438914389243893438944389543896438974389843899439004390143902439034390443905439064390743908439094391043911439124391343914439154391643917439184391943920439214392243923439244392543926439274392843929439304393143932439334393443935439364393743938439394394043941439424394343944439454394643947439484394943950439514395243953439544395543956439574395843959439604396143962439634396443965439664396743968439694397043971439724397343974439754397643977439784397943980439814398243983439844398543986439874398843989439904399143992439934399443995439964399743998439994400044001440024400344004440054400644007440084400944010440114401244013440144401544016440174401844019440204402144022440234402444025440264402744028440294403044031440324403344034440354403644037440384403944040440414404244043440444404544046440474404844049440504405144052440534405444055440564405744058440594406044061440624406344064440654406644067440684406944070440714407244073440744407544076440774407844079440804408144082440834408444085440864408744088440894409044091440924409344094440954409644097440984409944100441014410244103441044410544106441074410844109441104411144112441134411444115441164411744118441194412044121441224412344124441254412644127441284412944130441314413244133441344413544136441374413844139441404414144142441434414444145441464414744148441494415044151441524415344154441554415644157441584415944160441614416244163441644416544166441674416844169441704417144172441734417444175441764417744178441794418044181441824418344184441854418644187441884418944190441914419244193441944419544196441974419844199442004420144202442034420444205442064420744208442094421044211442124421344214442154421644217442184421944220442214422244223442244422544226442274422844229442304423144232442334423444235442364423744238442394424044241442424424344244442454424644247442484424944250442514425244253442544425544256442574425844259442604426144262442634426444265442664426744268442694427044271442724427344274442754427644277442784427944280442814428244283442844428544286442874428844289442904429144292442934429444295442964429744298442994430044301443024430344304443054430644307443084430944310443114431244313443144431544316443174431844319443204432144322443234432444325443264432744328443294433044331443324433344334443354433644337443384433944340443414434244343443444434544346443474434844349443504435144352443534435444355443564435744358443594436044361443624436344364443654436644367443684436944370443714437244373443744437544376443774437844379443804438144382443834438444385443864438744388443894439044391443924439344394443954439644397443984439944400444014440244403444044440544406444074440844409444104441144412444134441444415444164441744418444194442044421444224442344424444254442644427444284442944430444314443244433444344443544436444374443844439444404444144442444434444444445444464444744448444494445044451444524445344454444554445644457444584445944460444614446244463444644446544466444674446844469444704447144472444734447444475444764447744478444794448044481444824448344484444854448644487444884448944490444914449244493444944449544496444974449844499445004450144502445034450444505445064450744508445094451044511445124451344514445154451644517445184451944520445214452244523445244452544526445274452844529445304453144532445334453444535445364453744538445394454044541445424454344544445454454644547445484454944550445514455244553445544455544556445574455844559445604456144562445634456444565445664456744568445694457044571445724457344574445754457644577445784457944580445814458244583445844458544586445874458844589445904459144592445934459444595445964459744598445994460044601446024460344604446054460644607446084460944610446114461244613446144461544616446174461844619446204462144622446234462444625446264462744628446294463044631446324463344634446354463644637446384463944640446414464244643446444464544646446474464844649446504465144652446534465444655446564465744658446594466044661446624466344664446654466644667446684466944670446714467244673446744467544676446774467844679446804468144682446834468444685446864468744688446894469044691446924469344694446954469644697446984469944700447014470244703447044470544706447074470844709447104471144712447134471444715447164471744718447194472044721447224472344724447254472644727447284472944730447314473244733447344473544736447374473844739447404474144742447434474444745447464474744748447494475044751447524475344754447554475644757447584475944760447614476244763447644476544766447674476844769447704477144772447734477444775447764477744778447794478044781447824478344784447854478644787447884478944790447914479244793447944479544796447974479844799448004480144802448034480444805448064480744808448094481044811448124481344814448154481644817448184481944820448214482244823448244482544826448274482844829448304483144832448334483444835448364483744838448394484044841448424484344844448454484644847448484484944850448514485244853448544485544856448574485844859448604486144862448634486444865448664486744868448694487044871448724487344874448754487644877448784487944880448814488244883448844488544886448874488844889448904489144892448934489444895448964489744898448994490044901449024490344904449054490644907449084490944910449114491244913449144491544916449174491844919449204492144922449234492444925449264492744928449294493044931449324493344934449354493644937449384493944940449414494244943449444494544946449474494844949449504495144952449534495444955449564495744958449594496044961449624496344964449654496644967449684496944970449714497244973449744497544976449774497844979449804498144982449834498444985449864498744988449894499044991449924499344994449954499644997449984499945000450014500245003450044500545006450074500845009450104501145012450134501445015450164501745018450194502045021450224502345024450254502645027450284502945030450314503245033450344503545036450374503845039450404504145042450434504445045450464504745048450494505045051450524505345054450554505645057450584505945060450614506245063450644506545066450674506845069450704507145072450734507445075450764507745078450794508045081450824508345084450854508645087450884508945090450914509245093450944509545096450974509845099451004510145102451034510445105451064510745108451094511045111451124511345114451154511645117451184511945120451214512245123451244512545126451274512845129451304513145132451334513445135451364513745138451394514045141451424514345144451454514645147451484514945150451514515245153451544515545156451574515845159451604516145162451634516445165451664516745168451694517045171451724517345174451754517645177451784517945180451814518245183451844518545186451874518845189451904519145192451934519445195451964519745198451994520045201452024520345204452054520645207452084520945210452114521245213452144521545216452174521845219452204522145222452234522445225452264522745228452294523045231452324523345234452354523645237452384523945240452414524245243452444524545246452474524845249452504525145252452534525445255452564525745258452594526045261452624526345264452654526645267452684526945270452714527245273452744527545276452774527845279452804528145282452834528445285452864528745288452894529045291452924529345294452954529645297452984529945300453014530245303453044530545306453074530845309453104531145312453134531445315453164531745318453194532045321453224532345324453254532645327453284532945330453314533245333453344533545336453374533845339453404534145342453434534445345453464534745348453494535045351453524535345354453554535645357453584535945360453614536245363453644536545366453674536845369453704537145372453734537445375453764537745378453794538045381453824538345384453854538645387453884538945390453914539245393453944539545396453974539845399454004540145402454034540445405454064540745408454094541045411454124541345414454154541645417454184541945420454214542245423454244542545426454274542845429454304543145432454334543445435454364543745438454394544045441454424544345444454454544645447454484544945450454514545245453454544545545456454574545845459454604546145462454634546445465454664546745468454694547045471454724547345474454754547645477454784547945480454814548245483454844548545486454874548845489454904549145492454934549445495454964549745498454994550045501455024550345504455054550645507455084550945510455114551245513455144551545516455174551845519455204552145522455234552445525455264552745528455294553045531455324553345534455354553645537455384553945540455414554245543455444554545546455474554845549455504555145552455534555445555455564555745558455594556045561455624556345564455654556645567455684556945570455714557245573455744557545576455774557845579455804558145582455834558445585455864558745588455894559045591455924559345594455954559645597455984559945600456014560245603456044560545606456074560845609456104561145612456134561445615456164561745618456194562045621456224562345624456254562645627456284562945630456314563245633456344563545636456374563845639456404564145642456434564445645456464564745648456494565045651456524565345654456554565645657456584565945660456614566245663456644566545666456674566845669456704567145672456734567445675456764567745678456794568045681456824568345684456854568645687456884568945690456914569245693456944569545696456974569845699457004570145702457034570445705457064570745708457094571045711457124571345714457154571645717457184571945720457214572245723457244572545726457274572845729457304573145732457334573445735457364573745738457394574045741457424574345744457454574645747457484574945750457514575245753457544575545756457574575845759457604576145762457634576445765457664576745768457694577045771457724577345774457754577645777457784577945780457814578245783457844578545786457874578845789457904579145792457934579445795457964579745798457994580045801458024580345804458054580645807458084580945810458114581245813458144581545816458174581845819458204582145822458234582445825458264582745828458294583045831458324583345834458354583645837458384583945840458414584245843458444584545846458474584845849458504585145852458534585445855458564585745858458594586045861458624586345864458654586645867458684586945870458714587245873458744587545876458774587845879458804588145882458834588445885458864588745888458894589045891458924589345894458954589645897458984589945900459014590245903459044590545906459074590845909459104591145912459134591445915459164591745918459194592045921459224592345924459254592645927459284592945930459314593245933459344593545936459374593845939459404594145942459434594445945459464594745948459494595045951459524595345954459554595645957459584595945960459614596245963459644596545966459674596845969459704597145972459734597445975459764597745978459794598045981459824598345984459854598645987459884598945990459914599245993459944599545996459974599845999460004600146002460034600446005460064600746008460094601046011460124601346014460154601646017460184601946020460214602246023460244602546026460274602846029460304603146032460334603446035460364603746038460394604046041460424604346044460454604646047460484604946050460514605246053460544605546056460574605846059460604606146062460634606446065460664606746068460694607046071460724607346074460754607646077460784607946080460814608246083460844608546086460874608846089460904609146092460934609446095460964609746098460994610046101461024610346104461054610646107461084610946110461114611246113461144611546116461174611846119461204612146122461234612446125461264612746128461294613046131461324613346134461354613646137461384613946140461414614246143461444614546146461474614846149461504615146152461534615446155461564615746158461594616046161461624616346164461654616646167461684616946170461714617246173461744617546176461774617846179461804618146182461834618446185461864618746188461894619046191461924619346194461954619646197461984619946200462014620246203462044620546206462074620846209462104621146212462134621446215462164621746218462194622046221462224622346224462254622646227462284622946230462314623246233462344623546236462374623846239462404624146242462434624446245462464624746248462494625046251462524625346254462554625646257462584625946260462614626246263462644626546266462674626846269462704627146272462734627446275462764627746278462794628046281462824628346284462854628646287462884628946290462914629246293462944629546296462974629846299463004630146302463034630446305463064630746308463094631046311463124631346314463154631646317463184631946320463214632246323463244632546326463274632846329463304633146332463334633446335463364633746338463394634046341463424634346344463454634646347463484634946350463514635246353463544635546356463574635846359463604636146362463634636446365463664636746368463694637046371463724637346374463754637646377463784637946380463814638246383463844638546386463874638846389463904639146392463934639446395463964639746398463994640046401464024640346404464054640646407464084640946410464114641246413464144641546416464174641846419464204642146422464234642446425464264642746428464294643046431464324643346434464354643646437464384643946440464414644246443464444644546446464474644846449464504645146452464534645446455464564645746458464594646046461464624646346464464654646646467464684646946470464714647246473464744647546476464774647846479464804648146482464834648446485464864648746488464894649046491464924649346494464954649646497464984649946500465014650246503465044650546506465074650846509465104651146512465134651446515465164651746518465194652046521465224652346524465254652646527465284652946530465314653246533465344653546536465374653846539465404654146542465434654446545465464654746548465494655046551465524655346554465554655646557465584655946560465614656246563465644656546566465674656846569465704657146572465734657446575465764657746578465794658046581465824658346584465854658646587465884658946590465914659246593465944659546596465974659846599466004660146602466034660446605466064660746608466094661046611466124661346614466154661646617466184661946620466214662246623466244662546626466274662846629466304663146632466334663446635466364663746638466394664046641466424664346644466454664646647466484664946650466514665246653466544665546656466574665846659466604666146662466634666446665466664666746668466694667046671466724667346674466754667646677466784667946680466814668246683466844668546686466874668846689466904669146692466934669446695466964669746698466994670046701467024670346704467054670646707467084670946710467114671246713467144671546716467174671846719467204672146722467234672446725467264672746728467294673046731467324673346734467354673646737467384673946740467414674246743467444674546746467474674846749467504675146752467534675446755467564675746758467594676046761467624676346764467654676646767467684676946770467714677246773467744677546776467774677846779467804678146782467834678446785467864678746788467894679046791467924679346794467954679646797467984679946800468014680246803468044680546806468074680846809468104681146812468134681446815468164681746818468194682046821468224682346824468254682646827468284682946830468314683246833468344683546836468374683846839468404684146842468434684446845468464684746848468494685046851468524685346854468554685646857468584685946860468614686246863468644686546866468674686846869468704687146872468734687446875468764687746878468794688046881468824688346884468854688646887468884688946890468914689246893468944689546896468974689846899469004690146902469034690446905469064690746908469094691046911469124691346914469154691646917469184691946920469214692246923469244692546926469274692846929469304693146932469334693446935469364693746938469394694046941469424694346944469454694646947469484694946950469514695246953469544695546956469574695846959469604696146962469634696446965469664696746968469694697046971469724697346974469754697646977469784697946980469814698246983469844698546986469874698846989469904699146992469934699446995469964699746998469994700047001470024700347004470054700647007470084700947010470114701247013470144701547016470174701847019470204702147022470234702447025470264702747028470294703047031470324703347034470354703647037470384703947040470414704247043470444704547046470474704847049470504705147052470534705447055470564705747058470594706047061470624706347064470654706647067470684706947070470714707247073470744707547076470774707847079470804708147082470834708447085470864708747088470894709047091470924709347094470954709647097470984709947100471014710247103471044710547106471074710847109471104711147112471134711447115471164711747118471194712047121471224712347124471254712647127471284712947130471314713247133471344713547136471374713847139471404714147142471434714447145471464714747148471494715047151471524715347154471554715647157471584715947160471614716247163471644716547166471674716847169471704717147172471734717447175471764717747178471794718047181471824718347184471854718647187471884718947190471914719247193471944719547196471974719847199472004720147202472034720447205472064720747208472094721047211472124721347214472154721647217472184721947220472214722247223472244722547226472274722847229472304723147232472334723447235472364723747238472394724047241472424724347244472454724647247472484724947250472514725247253472544725547256472574725847259472604726147262472634726447265472664726747268472694727047271472724727347274472754727647277472784727947280472814728247283472844728547286472874728847289472904729147292472934729447295472964729747298472994730047301473024730347304473054730647307473084730947310473114731247313473144731547316473174731847319473204732147322473234732447325473264732747328473294733047331473324733347334473354733647337473384733947340473414734247343473444734547346473474734847349473504735147352473534735447355473564735747358473594736047361473624736347364473654736647367473684736947370473714737247373473744737547376473774737847379473804738147382473834738447385473864738747388473894739047391473924739347394473954739647397473984739947400474014740247403474044740547406474074740847409474104741147412474134741447415474164741747418474194742047421474224742347424474254742647427474284742947430474314743247433474344743547436474374743847439474404744147442474434744447445474464744747448474494745047451474524745347454474554745647457474584745947460474614746247463474644746547466474674746847469474704747147472474734747447475474764747747478474794748047481474824748347484474854748647487474884748947490474914749247493474944749547496474974749847499475004750147502475034750447505475064750747508475094751047511475124751347514475154751647517475184751947520475214752247523475244752547526475274752847529475304753147532475334753447535475364753747538475394754047541475424754347544475454754647547475484754947550475514755247553475544755547556475574755847559475604756147562475634756447565475664756747568475694757047571475724757347574475754757647577475784757947580475814758247583475844758547586475874758847589475904759147592475934759447595475964759747598475994760047601476024760347604476054760647607476084760947610476114761247613476144761547616476174761847619476204762147622476234762447625476264762747628476294763047631476324763347634476354763647637476384763947640476414764247643476444764547646476474764847649476504765147652476534765447655476564765747658476594766047661476624766347664476654766647667476684766947670476714767247673476744767547676476774767847679476804768147682476834768447685476864768747688476894769047691476924769347694476954769647697476984769947700477014770247703477044770547706477074770847709477104771147712477134771447715477164771747718477194772047721477224772347724477254772647727477284772947730477314773247733477344773547736477374773847739477404774147742477434774447745477464774747748477494775047751477524775347754477554775647757477584775947760477614776247763477644776547766477674776847769477704777147772477734777447775477764777747778477794778047781477824778347784477854778647787477884778947790477914779247793477944779547796477974779847799478004780147802478034780447805478064780747808478094781047811478124781347814478154781647817478184781947820478214782247823478244782547826478274782847829478304783147832478334783447835478364783747838478394784047841478424784347844478454784647847478484784947850478514785247853478544785547856478574785847859478604786147862478634786447865478664786747868478694787047871478724787347874478754787647877478784787947880478814788247883478844788547886478874788847889478904789147892478934789447895478964789747898478994790047901479024790347904479054790647907479084790947910479114791247913479144791547916479174791847919479204792147922479234792447925479264792747928479294793047931479324793347934479354793647937479384793947940479414794247943479444794547946479474794847949479504795147952479534795447955479564795747958479594796047961479624796347964479654796647967479684796947970479714797247973479744797547976479774797847979479804798147982479834798447985479864798747988479894799047991479924799347994479954799647997479984799948000480014800248003480044800548006480074800848009480104801148012480134801448015480164801748018480194802048021480224802348024480254802648027480284802948030480314803248033480344803548036480374803848039480404804148042480434804448045480464804748048480494805048051480524805348054480554805648057480584805948060480614806248063480644806548066480674806848069480704807148072480734807448075480764807748078480794808048081480824808348084480854808648087480884808948090480914809248093480944809548096480974809848099481004810148102481034810448105481064810748108481094811048111481124811348114481154811648117481184811948120481214812248123481244812548126481274812848129481304813148132481334813448135481364813748138481394814048141481424814348144481454814648147481484814948150481514815248153481544815548156481574815848159481604816148162481634816448165481664816748168481694817048171481724817348174481754817648177481784817948180481814818248183481844818548186481874818848189481904819148192481934819448195481964819748198481994820048201482024820348204482054820648207482084820948210482114821248213482144821548216482174821848219482204822148222482234822448225482264822748228482294823048231482324823348234482354823648237482384823948240482414824248243482444824548246482474824848249482504825148252482534825448255482564825748258482594826048261482624826348264482654826648267482684826948270482714827248273482744827548276482774827848279482804828148282482834828448285482864828748288482894829048291482924829348294482954829648297482984829948300483014830248303483044830548306483074830848309483104831148312483134831448315483164831748318483194832048321483224832348324483254832648327483284832948330483314833248333483344833548336483374833848339483404834148342483434834448345483464834748348483494835048351483524835348354483554835648357483584835948360483614836248363483644836548366483674836848369483704837148372483734837448375483764837748378483794838048381483824838348384483854838648387483884838948390483914839248393483944839548396483974839848399484004840148402484034840448405484064840748408484094841048411484124841348414484154841648417484184841948420484214842248423484244842548426484274842848429484304843148432484334843448435484364843748438484394844048441484424844348444484454844648447484484844948450484514845248453484544845548456484574845848459484604846148462484634846448465484664846748468484694847048471484724847348474484754847648477484784847948480484814848248483484844848548486484874848848489484904849148492484934849448495484964849748498484994850048501485024850348504485054850648507485084850948510485114851248513485144851548516485174851848519485204852148522485234852448525485264852748528485294853048531485324853348534485354853648537485384853948540485414854248543485444854548546485474854848549485504855148552485534855448555485564855748558485594856048561485624856348564485654856648567485684856948570485714857248573485744857548576485774857848579485804858148582485834858448585485864858748588485894859048591485924859348594485954859648597485984859948600486014860248603486044860548606486074860848609486104861148612486134861448615486164861748618486194862048621486224862348624486254862648627486284862948630486314863248633486344863548636486374863848639486404864148642486434864448645486464864748648486494865048651486524865348654486554865648657486584865948660486614866248663486644866548666486674866848669486704867148672486734867448675486764867748678486794868048681486824868348684486854868648687486884868948690486914869248693486944869548696486974869848699487004870148702487034870448705487064870748708487094871048711487124871348714487154871648717487184871948720487214872248723487244872548726487274872848729487304873148732487334873448735487364873748738487394874048741487424874348744487454874648747487484874948750487514875248753487544875548756487574875848759487604876148762487634876448765487664876748768487694877048771487724877348774487754877648777487784877948780487814878248783487844878548786487874878848789487904879148792487934879448795487964879748798487994880048801488024880348804488054880648807488084880948810488114881248813488144881548816488174881848819488204882148822488234882448825488264882748828488294883048831488324883348834488354883648837488384883948840488414884248843488444884548846488474884848849488504885148852488534885448855488564885748858488594886048861488624886348864488654886648867488684886948870488714887248873488744887548876488774887848879488804888148882488834888448885488864888748888488894889048891488924889348894488954889648897488984889948900489014890248903489044890548906489074890848909489104891148912489134891448915489164891748918489194892048921489224892348924489254892648927489284892948930489314893248933489344893548936489374893848939489404894148942489434894448945489464894748948489494895048951489524895348954489554895648957489584895948960489614896248963489644896548966489674896848969489704897148972489734897448975489764897748978489794898048981489824898348984489854898648987489884898948990489914899248993489944899548996489974899848999490004900149002490034900449005490064900749008490094901049011490124901349014490154901649017490184901949020490214902249023490244902549026490274902849029490304903149032490334903449035490364903749038490394904049041490424904349044490454904649047490484904949050490514905249053490544905549056490574905849059490604906149062490634906449065490664906749068490694907049071490724907349074490754907649077490784907949080490814908249083490844908549086490874908849089490904909149092490934909449095490964909749098490994910049101491024910349104491054910649107491084910949110491114911249113491144911549116491174911849119491204912149122491234912449125491264912749128491294913049131491324913349134491354913649137491384913949140491414914249143491444914549146491474914849149491504915149152491534915449155491564915749158491594916049161491624916349164491654916649167491684916949170491714917249173491744917549176491774917849179491804918149182491834918449185491864918749188491894919049191491924919349194491954919649197491984919949200492014920249203492044920549206492074920849209492104921149212492134921449215492164921749218492194922049221492224922349224492254922649227492284922949230492314923249233492344923549236492374923849239492404924149242492434924449245492464924749248492494925049251492524925349254492554925649257492584925949260492614926249263492644926549266492674926849269492704927149272492734927449275492764927749278492794928049281492824928349284492854928649287492884928949290492914929249293492944929549296492974929849299493004930149302493034930449305493064930749308493094931049311493124931349314493154931649317493184931949320493214932249323493244932549326493274932849329493304933149332493334933449335493364933749338493394934049341493424934349344493454934649347493484934949350493514935249353493544935549356493574935849359493604936149362493634936449365493664936749368493694937049371493724937349374493754937649377493784937949380493814938249383493844938549386493874938849389493904939149392493934939449395493964939749398493994940049401494024940349404494054940649407494084940949410494114941249413494144941549416494174941849419494204942149422494234942449425494264942749428494294943049431494324943349434494354943649437494384943949440494414944249443494444944549446494474944849449494504945149452494534945449455494564945749458494594946049461494624946349464494654946649467494684946949470494714947249473494744947549476494774947849479494804948149482494834948449485494864948749488494894949049491494924949349494494954949649497494984949949500495014950249503495044950549506495074950849509495104951149512495134951449515495164951749518495194952049521495224952349524495254952649527495284952949530495314953249533495344953549536495374953849539495404954149542495434954449545495464954749548495494955049551495524955349554495554955649557495584955949560495614956249563495644956549566495674956849569495704957149572495734957449575495764957749578495794958049581495824958349584495854958649587495884958949590495914959249593495944959549596495974959849599496004960149602496034960449605496064960749608496094961049611496124961349614496154961649617496184961949620496214962249623496244962549626496274962849629496304963149632496334963449635496364963749638496394964049641496424964349644496454964649647496484964949650496514965249653496544965549656496574965849659496604966149662496634966449665496664966749668496694967049671496724967349674496754967649677496784967949680496814968249683496844968549686496874968849689496904969149692496934969449695496964969749698496994970049701497024970349704497054970649707497084970949710497114971249713497144971549716497174971849719497204972149722497234972449725497264972749728497294973049731497324973349734497354973649737497384973949740497414974249743497444974549746497474974849749497504975149752497534975449755497564975749758497594976049761497624976349764497654976649767497684976949770497714977249773497744977549776497774977849779497804978149782497834978449785497864978749788497894979049791497924979349794497954979649797497984979949800498014980249803498044980549806498074980849809498104981149812498134981449815498164981749818498194982049821498224982349824498254982649827498284982949830498314983249833498344983549836498374983849839498404984149842498434984449845498464984749848498494985049851498524985349854498554985649857498584985949860498614986249863498644986549866498674986849869498704987149872498734987449875498764987749878498794988049881498824988349884498854988649887498884988949890498914989249893498944989549896498974989849899499004990149902499034990449905499064990749908499094991049911499124991349914499154991649917499184991949920499214992249923499244992549926499274992849929499304993149932499334993449935499364993749938499394994049941499424994349944499454994649947499484994949950499514995249953499544995549956499574995849959499604996149962499634996449965499664996749968499694997049971499724997349974499754997649977499784997949980499814998249983499844998549986499874998849989499904999149992499934999449995499964999749998499995000050001500025000350004500055000650007500085000950010500115001250013500145001550016500175001850019500205002150022500235002450025500265002750028500295003050031500325003350034500355003650037500385003950040500415004250043500445004550046500475004850049500505005150052500535005450055500565005750058500595006050061500625006350064500655006650067500685006950070500715007250073500745007550076500775007850079500805008150082500835008450085500865008750088500895009050091500925009350094500955009650097500985009950100501015010250103501045010550106501075010850109501105011150112501135011450115501165011750118501195012050121501225012350124501255012650127501285012950130501315013250133501345013550136501375013850139501405014150142501435014450145501465014750148501495015050151501525015350154501555015650157501585015950160501615016250163501645016550166501675016850169501705017150172
  1. From 536821c33d55b5d714910c015008d2cebd7dfef5 Mon Sep 17 00:00:00 2001
  2. From: Robert Ogden <robertogden@chromium.org>
  3. Date: Wed, 25 May 2022 11:03:46 -0700
  4. Subject: [PATCH 8/9] run clang format
  5. ---
  6. .../configuration/edgetpu_coral_plugin.cc | 20 +-
  7. .../edgetpu_coral_plugin_test.cc | 3 +-
  8. .../src/tensorflow_lite_support/c/common.cc | 2 +-
  9. .../src/tensorflow_lite_support/c/common.h | 4 +-
  10. .../tensorflow_lite_support/c/common_utils.cc | 11 +-
  11. .../tensorflow_lite_support/c/common_utils.h | 3 +-
  12. .../c/task/audio/audio_classifier.cc | 12 +-
  13. .../c/task/audio/audio_classifier.h | 12 +-
  14. .../c/task/audio/core/audio_buffer.h | 4 +-
  15. .../c/task/processor/classification_result.cc | 2 +-
  16. .../c/task/text/bert_nl_classifier.cc | 6 +-
  17. .../c/task/text/bert_nl_classifier.h | 6 +-
  18. .../c/task/text/bert_question_answerer.cc | 3 +-
  19. .../c/task/text/bert_question_answerer.h | 3 +-
  20. .../c/task/text/nl_classifier.cc | 3 +-
  21. .../c/task/text/nl_classifier.h | 3 +-
  22. .../c/task/vision/image_classifier.cc | 9 +-
  23. .../c/task/vision/image_classifier.h | 9 +-
  24. .../c/task/vision/image_segmenter.cc | 6 +-
  25. .../c/task/vision/image_segmenter.h | 6 +-
  26. .../c/task/vision/object_detector.cc | 6 +-
  27. .../c/task/vision/object_detector.h | 6 +-
  28. .../test/task/audio/audio_classifier_test.cc | 32 +-
  29. .../test/task/vision/image_classifier_test.cc | 84 +-
  30. .../test/task/vision/image_segmenter_test.cc | 62 +-
  31. .../test/task/vision/object_detector_test.cc | 90 +-
  32. .../src/tensorflow_lite_support/cc/common.cc | 2 +-
  33. .../src/tensorflow_lite_support/cc/common.h | 5 +-
  34. .../cc/port/default/status_macros.h | 2 +-
  35. .../cc/port/default/statusor_internals.h | 38 +-
  36. .../cc/port/default/tflite_wrapper.cc | 9 +-
  37. .../cc/port/default/tflite_wrapper.h | 2 +-
  38. .../cc/port/integral_types.h | 2 +-
  39. .../cc/task/audio/audio_classifier.cc | 2 +-
  40. .../cc/task/audio/audio_embedder.cc | 3 +-
  41. .../cc/task/audio/audio_embedder.h | 9 +-
  42. .../cc/task/audio/core/audio_buffer.h | 10 +-
  43. .../cc/task/audio/utils/audio_utils.cc | 3 +-
  44. .../cc/task/audio/utils/audio_utils.h | 3 +-
  45. .../cc/task/audio/utils/wav_io.cc | 19 +-
  46. .../cc/task/audio/utils/wav_io.h | 6 +-
  47. .../cc/task/core/base_task_api.h | 2 +-
  48. .../cc/task/core/classification_head.h | 2 +-
  49. .../cc/task/core/error_reporter.cc | 8 +-
  50. .../cc/task/core/external_file_handler.cc | 7 +-
  51. .../cc/task/core/external_file_handler.h | 3 +-
  52. .../cc/task/core/label_map_item.cc | 5 +-
  53. .../cc/task/core/label_map_item.h | 7 +-
  54. .../cc/task/core/score_calibration.cc | 8 +-
  55. .../cc/task/core/score_calibration.h | 11 +-
  56. .../cc/task/core/task_api_factory.h | 8 +-
  57. .../cc/task/core/task_utils.h | 30 +-
  58. .../cc/task/core/tflite_engine.cc | 14 +-
  59. .../cc/task/core/tflite_engine.h | 13 +-
  60. .../cc/task/processor/audio_preprocessor.cc | 5 +-
  61. .../processor/classification_postprocessor.cc | 5 +-
  62. .../task/processor/embedding_postprocessor.h | 10 +-
  63. .../cc/task/processor/image_preprocessor.cc | 6 +-
  64. .../cc/task/processor/processor.h | 5 +-
  65. .../cc/task/processor/regex_preprocessor.cc | 3 +-
  66. .../cc/task/processor/regex_preprocessor.h | 3 +-
  67. .../cc/task/processor/search_postprocessor.cc | 40 +-
  68. .../cc/task/processor/search_postprocessor.h | 37 +-
  69. .../cc/task/text/bert_clu_annotator.cc | 4 +-
  70. .../cc/task/text/bert_nl_classifier.cc | 3 +-
  71. .../cc/task/text/bert_nl_classifier.h | 2 +-
  72. .../cc/task/text/bert_question_answerer.cc | 32 +-
  73. .../cc/task/text/bert_question_answerer.h | 7 +-
  74. .../cc/task/text/clu_lib/bert_utils.cc | 14 +-
  75. .../cc/task/text/clu_lib/bert_utils.h | 7 +-
  76. .../cc/task/text/clu_lib/intent_repr.cc | 18 +-
  77. .../cc/task/text/clu_lib/intent_repr.h | 5 +-
  78. .../cc/task/text/clu_lib/slot_repr.cc | 32 +-
  79. .../cc/task/text/clu_lib/slot_repr.h | 9 +-
  80. .../task/text/clu_lib/slot_tagging_output.cc | 24 +-
  81. .../task/text/clu_lib/slot_tagging_output.h | 6 +-
  82. .../cc/task/text/clu_lib/tflite_modules.cc | 41 +-
  83. .../cc/task/text/clu_lib/tflite_modules.h | 17 +-
  84. .../cc/task/text/clu_lib/tflite_test_utils.cc | 14 +-
  85. .../cc/task/text/clu_lib/tflite_test_utils.h | 6 +-
  86. .../task/text/nlclassifier/nl_classifier.cc | 18 +-
  87. .../cc/task/text/nlclassifier/nl_classifier.h | 19 +-
  88. .../text/proto/text_searcher_options.proto | 1 -
  89. .../cc/task/text/question_answerer.h | 6 +-
  90. .../cc/task/text/text_embedder.cc | 6 +-
  91. .../cc/task/text/text_embedder.h | 3 +-
  92. .../cc/task/text/text_searcher.h | 4 +-
  93. .../text/universal_sentence_encoder_qa.cc | 14 +-
  94. .../task/text/universal_sentence_encoder_qa.h | 7 +-
  95. .../cc/task/text/utils/bert_utils.cc | 2 +-
  96. .../task/vision/core/base_vision_task_api.h | 9 +-
  97. .../cc/task/vision/core/classification_head.h | 2 +-
  98. .../cc/task/vision/core/frame_buffer.h | 47 +-
  99. .../cc/task/vision/core/label_map_item.cc | 5 +-
  100. .../cc/task/vision/core/label_map_item.h | 7 +-
  101. .../cc/task/vision/image_classifier.cc | 14 +-
  102. .../cc/task/vision/image_classifier.h | 8 +-
  103. .../cc/task/vision/image_embedder.cc | 17 +-
  104. .../cc/task/vision/image_embedder.h | 9 +-
  105. .../cc/task/vision/image_searcher.cc | 7 +-
  106. .../cc/task/vision/image_searcher.h | 8 +-
  107. .../cc/task/vision/image_segmenter.cc | 17 +-
  108. .../cc/task/vision/image_segmenter.h | 8 +-
  109. .../cc/task/vision/object_detector.cc | 14 +-
  110. .../cc/task/vision/object_detector.h | 5 +-
  111. .../vision/proto/image_searcher_options.proto | 2 -
  112. .../vision/utils/frame_buffer_common_utils.cc | 59 +-
  113. .../vision/utils/frame_buffer_common_utils.h | 37 +-
  114. .../task/vision/utils/frame_buffer_utils.cc | 50 +-
  115. .../cc/task/vision/utils/frame_buffer_utils.h | 40 +-
  116. .../utils/frame_buffer_utils_interface.h | 11 +-
  117. .../cc/task/vision/utils/image_utils.cc | 12 +-
  118. .../cc/task/vision/utils/image_utils.h | 2 +-
  119. .../vision/utils/libyuv_frame_buffer_utils.cc | 81 +-
  120. .../vision/utils/libyuv_frame_buffer_utils.h | 9 +-
  121. .../cc/task/vision/utils/score_calibration.cc | 8 +-
  122. .../cc/task/vision/utils/score_calibration.h | 11 +-
  123. .../cc/test/common_test.cc | 2 +-
  124. .../task/processor/image_preprocessor_test.cc | 13 +-
  125. .../test/task/text/bert_nl_classifier_test.cc | 36 +-
  126. .../task/text/bert_question_answerer_test.cc | 7 +-
  127. .../test/task/text/clu_lib/bert_utils_test.cc | 32 +-
  128. .../task/text/clu_lib/intent_repr_test.cc | 2 +-
  129. .../text/nlclassifier/nl_classifier_test.cc | 83 +-
  130. .../cc/test/task/text/text_embedder_test.cc | 26 +-
  131. .../cc/test/task/text/text_searcher_test.cc | 18 +-
  132. .../universal_sentence_encoder_qa_test.cc | 16 +-
  133. .../test/task/vision/image_classifier_test.cc | 158 +-
  134. .../test/task/vision/image_embedder_test.cc | 95 +-
  135. .../test/task/vision/image_searcher_test.cc | 62 +-
  136. .../test/task/vision/image_segmenter_test.cc | 117 +-
  137. .../test/task/vision/object_detector_test.cc | 157 +-
  138. .../cc/test/test_utils.cc | 18 +-
  139. .../cc/test/test_utils.h | 6 +-
  140. .../cc/text/tokenizers/bert_tokenizer.cc | 3 +-
  141. .../cc/text/tokenizers/bert_tokenizer.h | 3 +-
  142. .../cc/text/tokenizers/bert_tokenizer_jni.cc | 25 +-
  143. .../cc/text/tokenizers/regex_tokenizer.cc | 4 +-
  144. .../cc/text/tokenizers/sentencepiece_jni.cc | 20 +-
  145. .../cc/text/tokenizers/tokenizer_jni_lib.cc | 3 +-
  146. .../cc/text/tokenizers/tokenizer_jni_lib.h | 3 +-
  147. .../cc/text/tokenizers/tokenizer_utils.cc | 6 +-
  148. .../cc/text/tokenizers/tokenizer_utils.h | 1 -
  149. .../cc/utils/common_utils.cc | 3 +-
  150. .../cc/utils/common_utils.h | 3 +-
  151. .../cc/utils/jni_utils.cc | 7 +-
  152. .../cc/utils/jni_utils.h | 9 +-
  153. .../codegen/android_java_generator.cc | 37 +-
  154. .../codegen/android_java_generator.h | 5 +-
  155. .../codegen/code_generator.cc | 3 +-
  156. .../codegen/code_generator.h | 3 +-
  157. .../codegen/code_generator_test.cc | 3 +-
  158. .../codegen/metadata_helper.h | 2 +-
  159. .../codegen/python/codegen_lib.cc | 9 +-
  160. .../tensorflow_lite_support/codegen/utils.cc | 36 +-
  161. .../custom_ops/kernel/ngrams.cc | 7 +-
  162. .../custom_ops/kernel/ngrams_op_resolver.cc | 2 +-
  163. .../custom_ops/kernel/ngrams_test.cc | 9 +-
  164. .../kernel/ragged/py_tflite_registerer.h | 2 +-
  165. .../kernel/ragged/ragged_range_tflite.cc | 9 +-
  166. .../kernel/ragged/ragged_range_tflite_test.cc | 3 +-
  167. .../ragged/ragged_tensor_to_tensor_tflite.cc | 47 +-
  168. .../ragged_tensor_to_tensor_tflite_test.cc | 6 +-
  169. .../kernel/sentencepiece/model_converter.cc | 10 +-
  170. .../kernel/sentencepiece/model_converter.h | 6 +-
  171. .../sentencepiece/optimized_decoder_test.cc | 6 +-
  172. .../kernel/sentencepiece/optimized_encoder.cc | 23 +-
  173. .../kernel/sentencepiece/optimized_encoder.h | 10 +-
  174. .../sentencepiece/optimized_encoder_test.cc | 8 +-
  175. .../sentencepiece/py_tflite_registerer.h | 2 +-
  176. .../sentencepiece_detokenizer_tflite.cc | 3 +-
  177. .../sentencepiece_tokenizer_op.cc | 6 +-
  178. .../sentencepiece_tokenizer_tflite.cc | 7 +-
  179. .../custom_ops/kernel/whitespace_tokenizer.cc | 13 +-
  180. .../whitespace_tokenizer_op_resolver.cc | 2 +-
  181. .../audio/desktop/audio_classifier_demo.cc | 16 +-
  182. .../audio/desktop/audio_classifier_lib.cc | 11 +-
  183. .../task/audio/desktop/audio_classifier_lib.h | 3 +-
  184. .../text/desktop/bert_nl_classifier_demo.cc | 14 +-
  185. .../desktop/bert_question_answerer_demo.cc | 18 +-
  186. .../task/text/desktop/nl_classifier_demo.cc | 14 +-
  187. .../task/text/desktop/text_embedder_demo.cc | 26 +-
  188. .../task/text/desktop/text_searcher_demo.cc | 30 +-
  189. .../universal_sentence_encoder_qa_demo.cc | 17 +-
  190. .../vision/desktop/image_classifier_demo.cc | 34 +-
  191. .../vision/desktop/image_embedder_demo.cc | 30 +-
  192. .../vision/desktop/image_searcher_demo.cc | 30 +-
  193. .../vision/desktop/image_segmenter_demo.cc | 24 +-
  194. .../vision/desktop/object_detector_demo.cc | 40 +-
  195. .../ios/sources/TFLCommon.h | 11 +-
  196. .../ios/sources/TFLCommonUtils.h | 32 +-
  197. .../ios/sources/TFLCommonUtils.m | 19 +-
  198. .../task/audio/core/sources/TFLFloatBuffer.h | 18 +-
  199. .../task/audio/core/sources/TFLFloatBuffer.m | 4 +-
  200. .../task/audio/core/sources/TFLRingBuffer.h | 32 +-
  201. .../task/audio/core/sources/TFLRingBuffer.m | 49 +-
  202. .../core/sources/TFLBaseOptions+Helpers.h | 2 +-
  203. .../ios/task/core/sources/TFLBaseOptions.h | 32 +-
  204. .../processor/sources/TFLCategory+Helpers.h | 2 +-
  205. .../processor/sources/TFLCategory+Helpers.m | 7 +-
  206. .../ios/task/processor/sources/TFLCategory.h | 22 +-
  207. .../ios/task/processor/sources/TFLCategory.m | 4 +-
  208. .../TFLClassificationOptions+Helpers.h | 6 +-
  209. .../TFLClassificationOptions+Helpers.m | 33 +-
  210. .../sources/TFLClassificationOptions.h | 9 +-
  211. .../sources/TFLClassificationResult+Helpers.h | 17 +-
  212. .../sources/TFLClassificationResult+Helpers.m | 22 +-
  213. .../sources/TFLClassificationResult.h | 79 +-
  214. .../sources/TFLClassificationResult.m | 12 +-
  215. .../sources/TFLDetectionResult+Helpers.h | 11 +-
  216. .../sources/TFLDetectionResult+Helpers.m | 15 +-
  217. .../processor/sources/TFLDetectionResult.h | 35 +-
  218. .../processor/sources/TFLDetectionResult.m | 4 +-
  219. .../sources/TFLSegmentationResult+Helpers.h | 4 +-
  220. .../sources/TFLSegmentationResult+Helpers.m | 44 +-
  221. .../processor/sources/TFLSegmentationResult.h | 65 +-
  222. .../processor/sources/TFLSegmentationResult.m | 45 +-
  223. .../Sources/TFLBertNLClassifier.h | 21 +-
  224. .../nlclassifier/Sources/TFLNLClassifier.h | 47 +-
  225. .../text/qa/Sources/TFLBertQuestionAnswerer.h | 4 +-
  226. .../task/vision/sources/TFLImageClassifier.h | 90 +-
  227. .../task/vision/sources/TFLImageClassifier.m | 58 +-
  228. .../task/vision/sources/TFLImageSegmenter.h | 62 +-
  229. .../task/vision/sources/TFLImageSegmenter.m | 49 +-
  230. .../task/vision/sources/TFLObjectDetector.h | 64 +-
  231. .../task/vision/sources/TFLObjectDetector.m | 54 +-
  232. .../vision/utils/sources/GMLImage+Utils.h | 8 +-
  233. .../vision/utils/sources/GMLImage+Utils.m | 225 +-
  234. .../test/task/audio/core/TFLRingBufferTests.m | 171 +-
  235. .../TFLImageClassifierTests.m | 28 +-
  236. .../image_segmenter/TFLImageSegmenterTests.m | 64 +-
  237. .../object_detector/TFLObjectDetectorTests.m | 36 +-
  238. .../tokenizers/Sources/TFLBertTokenizer.h | 6 +-
  239. .../Sources/TFLSentencepieceTokenizer.h | 2 +-
  240. .../text/tokenizers/Sources/TFLTokenizer.h | 4 +-
  241. .../tokenizers/Sources/TFLTokenizerUtil.h | 11 +-
  242. .../ios/utils/Sources/TFLStringUtil.mm | 11 +-
  243. .../lite/support/audio/TensorAudio.java | 524 ++---
  244. .../lite/support/common/FileUtil.java | 301 +--
  245. .../lite/support/common/Operator.java | 15 +-
  246. .../lite/support/common/Processor.java | 2 +-
  247. .../support/common/SequentialProcessor.java | 83 +-
  248. .../lite/support/common/TensorOperator.java | 6 +-
  249. .../lite/support/common/TensorProcessor.java | 57 +-
  250. .../common/internal/SupportPreconditions.java | 302 +--
  251. .../lite/support/common/ops/CastOp.java | 55 +-
  252. .../lite/support/common/ops/DequantizeOp.java | 9 +-
  253. .../lite/support/common/ops/NormalizeOp.java | 245 ++-
  254. .../lite/support/common/ops/QuantizeOp.java | 9 +-
  255. .../lite/support/image/BitmapContainer.java | 116 +-
  256. .../lite/support/image/BoundingBoxUtil.java | 369 ++--
  257. .../lite/support/image/ColorSpaceType.java | 623 +++---
  258. .../lite/support/image/ImageContainer.java | 36 +-
  259. .../lite/support/image/ImageConversions.java | 217 +-
  260. .../lite/support/image/ImageOperator.java | 41 +-
  261. .../lite/support/image/ImageProcessor.java | 285 +--
  262. .../lite/support/image/ImageProperties.java | 91 +-
  263. .../support/image/MediaImageContainer.java | 112 +-
  264. .../lite/support/image/MlImageAdapter.java | 160 +-
  265. .../support/image/TensorBufferContainer.java | 202 +-
  266. .../lite/support/image/TensorImage.java | 677 +++---
  267. .../lite/support/image/ops/ResizeOp.java | 105 +-
  268. .../image/ops/ResizeWithCropOrPadOp.java | 170 +-
  269. .../lite/support/image/ops/Rot90Op.java | 141 +-
  270. .../image/ops/TensorOperatorWrapper.java | 78 +-
  271. .../image/ops/TransformToGrayscaleOp.java | 127 +-
  272. .../lite/support/label/Category.java | 192 +-
  273. .../lite/support/label/LabelUtil.java | 77 +-
  274. .../lite/support/label/TensorLabel.java | 331 +--
  275. .../lite/support/label/ops/LabelAxisOp.java | 70 +-
  276. .../lite/support/model/GpuDelegateProxy.java | 71 +-
  277. .../tensorflow/lite/support/model/Model.java | 467 +++--
  278. .../support/tensorbuffer/TensorBuffer.java | 899 ++++----
  279. .../tensorbuffer/TensorBufferFloat.java | 181 +-
  280. .../tensorbuffer/TensorBufferUint8.java | 188 +-
  281. .../audio/classifier/AudioClassifier.java | 857 ++++----
  282. .../audio/classifier/Classifications.java | 28 +-
  283. .../lite/task/core/BaseOptions.java | 105 +-
  284. .../lite/task/core/BaseTaskApi.java | 122 +-
  285. .../lite/task/core/ComputeSettings.java | 48 +-
  286. .../lite/task/core/TaskJniUtils.java | 275 ++-
  287. .../core/annotations/UsedByReflection.java | 2 +-
  288. .../core/vision/ImageProcessingOptions.java | 125 +-
  289. .../lite/task/processor/NearestNeighbor.java | 53 +-
  290. .../lite/task/processor/SearcherOptions.java | 114 +-
  291. .../text/nlclassifier/BertNLClassifier.java | 391 ++--
  292. .../task/text/nlclassifier/NLClassifier.java | 568 ++---
  293. .../task/text/qa/BertQuestionAnswerer.java | 394 ++--
  294. .../lite/task/text/qa/QaAnswer.java | 60 +-
  295. .../lite/task/text/qa/QuestionAnswerer.java | 19 +-
  296. .../lite/task/text/searcher/TextSearcher.java | 375 ++--
  297. .../vision/classifier/Classifications.java | 25 +-
  298. .../vision/classifier/ImageClassifier.java | 882 ++++----
  299. .../task/vision/core/BaseVisionTaskApi.java | 349 ++--
  300. .../lite/task/vision/detector/Detection.java | 26 +-
  301. .../task/vision/detector/ObjectDetector.java | 873 ++++----
  302. .../task/vision/searcher/ImageSearcher.java | 605 +++---
  303. .../task/vision/segmenter/ColoredLabel.java | 112 +-
  304. .../task/vision/segmenter/ImageSegmenter.java | 752 ++++---
  305. .../task/vision/segmenter/OutputType.java | 202 +-
  306. .../task/vision/segmenter/Segmentation.java | 106 +-
  307. .../lite/support/audio/TensorAudioTest.java | 505 ++---
  308. .../lite/support/common/FileUtilTest.java | 129 +-
  309. .../support/common/TensorProcessorTest.java | 91 +-
  310. .../lite/support/common/ops/CastOpTest.java | 91 +-
  311. .../support/common/ops/DequantizeOpTest.java | 23 +-
  312. .../support/common/ops/NormalizeOpTest.java | 217 +-
  313. .../support/common/ops/QuantizeOpTest.java | 21 +-
  314. .../support/image/BoundingBoxUtilTest.java | 343 ++--
  315. .../image/ColorSpaceTypeInstrumentedTest.java | 37 +-
  316. .../support/image/ColorSpaceTypeTest.java | 703 +++----
  317. .../ImageConversionsInstrumentedTest.java | 338 +--
  318. .../support/image/ImageConversionsTest.java | 164 +-
  319. .../image/ImageProcessorInstrumentedTest.java | 221 +-
  320. .../support/image/ImageProcessorTest.java | 209 +-
  321. .../support/image/MlImageAdapterTest.java | 259 +--
  322. .../image/TensorImageInstrumentedTest.java | 208 +-
  323. .../lite/support/image/TensorImageTest.java | 1391 ++++++-------
  324. .../lite/support/image/TestImageCreator.java | 183 +-
  325. .../image/ops/ResizeOpInstrumentedTest.java | 103 +-
  326. ...ResizeWithCropOrPadOpInstrumentedTest.java | 239 ++-
  327. .../image/ops/Rot90OpInstrumentedTest.java | 122 +-
  328. ...ransformToGrayScaleOpInstrumentedTest.java | 104 +-
  329. .../lite/support/label/CategoryTest.java | 204 +-
  330. .../lite/support/label/LabelUtilTest.java | 47 +-
  331. .../lite/support/label/TensorLabelTest.java | 327 +--
  332. .../support/label/ops/LabelAxisOpTest.java | 160 +-
  333. .../GpuDelegateProxyInstrumentedTest.java | 18 +-
  334. .../support/model/GpuDelegateProxyTest.java | 11 +-
  335. .../lite/support/model/ModelTest.java | 244 +--
  336. .../tensorbuffer/TensorBufferFloatTest.java | 82 +-
  337. .../tensorbuffer/TensorBufferTest.java | 1707 +++++++--------
  338. .../tensorbuffer/TensorBufferUint8Test.java | 82 +-
  339. .../audio/classifier/audio_classifier_jni.cc | 42 +-
  340. .../src/native/task/core/task_jni_utils.cc | 5 +-
  341. .../bert/bert_nl_classifier_jni.cc | 23 +-
  342. .../text/nlclassifier/nl_classifier_jni.cc | 21 +-
  343. .../text/qa/bert_question_answerer_jni.cc | 24 +-
  344. .../task/text/searcher/text_searcher_jni.cc | 36 +-
  345. .../vision/classifier/image_classifier_jni.cc | 27 +-
  346. .../vision/core/base_vision_task_api_jni.cc | 40 +-
  347. .../vision/detector/object_detector_jni.cc | 27 +-
  348. .../java/src/native/task/vision/jni_utils.cc | 30 +-
  349. .../java/src/native/task/vision/jni_utils.h | 28 +-
  350. .../vision/searcher/image_searcher_jni.cc | 36 +-
  351. .../vision/segmenter/image_segmenter_jni.cc | 32 +-
  352. .../metadata/cc/metadata_extractor.cc | 20 +-
  353. .../metadata/cc/metadata_extractor.h | 4 +-
  354. .../metadata/cc/metadata_populator.cc | 2 +-
  355. .../metadata/cc/metadata_populator.h | 7 +-
  356. .../metadata/cc/metadata_version.cc | 33 +-
  357. .../cc/utils/zip_readonly_mem_file.cc | 13 +-
  358. .../metadata/cc/utils/zip_readonly_mem_file.h | 4 +-
  359. .../cc/utils/zip_writable_mem_file.cc | 17 +-
  360. .../metadata/cc/utils/zip_writable_mem_file.h | 4 +-
  361. .../flatbuffers_lib/flatbuffers_lib.cc | 2 +-
  362. .../support/metadata/BoundedInputStream.java | 138 +-
  363. .../support/metadata/ByteBufferChannel.java | 188 +-
  364. .../support/metadata/MetadataExtractor.java | 622 +++---
  365. .../lite/support/metadata/MetadataParser.java | 12 +-
  366. .../lite/support/metadata/ModelInfo.java | 448 ++--
  367. .../support/metadata/ModelMetadataInfo.java | 243 ++-
  368. .../lite/support/metadata/Preconditions.java | 306 +--
  369. .../metadata/SeekableByteChannelCompat.java | 140 +-
  370. .../lite/support/metadata/ZipFile.java | 686 +++----
  371. .../metadata/BoundedInputStreamTest.java | 429 ++--
  372. .../metadata/ByteBufferChannelTest.java | 480 +++--
  373. .../metadata/MetadataExtractorTest.java | 1828 ++++++++---------
  374. .../support/metadata/MetadataParserTest.java | 18 +-
  375. .../lite/support/metadata/ZipFileTest.java | 206 +-
  376. .../odml/ios/image/apis/GMLImage.h | 47 +-
  377. .../android/odml/image/BitmapExtractor.java | 43 +-
  378. .../odml/image/BitmapImageContainer.java | 70 +-
  379. .../odml/image/BitmapMlImageBuilder.java | 137 +-
  380. .../odml/image/ByteBufferExtractor.java | 421 ++--
  381. .../odml/image/ByteBufferImageContainer.java | 68 +-
  382. .../odml/image/ByteBufferMlImageBuilder.java | 135 +-
  383. .../android/odml/image/ImageContainer.java | 12 +-
  384. .../android/odml/image/ImageProperties.java | 92 +-
  385. .../odml/image/MediaImageContainer.java | 81 +-
  386. .../odml/image/MediaImageExtractor.java | 42 +-
  387. .../odml/image/MediaMlImageBuilder.java | 105 +-
  388. .../google/android/odml/image/MlImage.java | 423 ++--
  389. .../odml/image/BitmapExtractorTest.java | 46 +-
  390. .../odml/image/BitmapMlImageBuilderTest.java | 116 +-
  391. .../odml/image/ByteBufferExtractorTest.java | 264 ++-
  392. .../image/ByteBufferMlImageBuilderTest.java | 93 +-
  393. .../odml/image/MediaImageExtractorTest.java | 48 +-
  394. .../odml/image/MediaMlImageBuilderTest.java | 109 +-
  395. .../android/odml/image/TestImageCreator.java | 211 +-
  396. .../core/pybinds/_pywrap_audio_buffer.cc | 17 +-
  397. .../audio/pybinds/_pywrap_audio_classifier.cc | 1 -
  398. .../audio/pybinds/_pywrap_audio_embedder.cc | 22 +-
  399. .../task/vision/core/pybinds/image_utils.cc | 4 +-
  400. .../pybinds/_pywrap_image_classifier.cc | 16 +-
  401. .../vision/pybinds/_pywrap_image_segmenter.cc | 12 +-
  402. .../vision/pybinds/_pywrap_object_detector.cc | 13 +-
  403. .../scann_ondevice/cc/core/index_table_sum.h | 41 +-
  404. .../scann_ondevice/cc/core/indexer.cc | 24 +-
  405. .../scann_ondevice/cc/core/indexer.h | 2 +-
  406. .../scann_ondevice/cc/core/indexer_test.cc | 6 +-
  407. .../scann_ondevice/cc/core/partitioner.cc | 8 +-
  408. .../scann_ondevice/cc/core/partitioner.h | 5 +-
  409. .../scann_ondevice/cc/core/searcher.h | 29 +-
  410. .../scann_ondevice/cc/core/searcher_test.cc | 9 +-
  411. .../cc/core/top_n_amortized_constant.h | 12 +-
  412. .../scann_ondevice/cc/index.cc | 23 +-
  413. .../scann_ondevice/cc/index.h | 13 +-
  414. .../scann_ondevice/cc/index_builder.cc | 24 +-
  415. .../scann_ondevice/cc/index_builder.h | 14 +-
  416. .../cc/mem_random_access_file.cc | 7 +-
  417. .../cc/mem_random_access_file.h | 8 +-
  418. .../scann_ondevice/cc/mem_writable_file.h | 8 +-
  419. .../cc/python/index_builder_py_wrapper.cc | 6 +-
  420. .../cc/test/index_builder_test.cc | 143 +-
  421. .../scann_ondevice/cc/test/index_test.cc | 33 +-
  422. .../cc/test/mem_writable_file_test.cc | 2 +-
  423. .../leveldb_testing_utils_py_wrapper.cc | 14 +-
  424. .../src/third_party/fft2d/fft.h | 12 +-
  425. .../src/third_party/fft2d/fft2d.h | 12 +-
  426. 420 files changed, 19248 insertions(+), 18509 deletions(-)
  427. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc
  428. index 9f27f3baae82f..6a16d12856258 100644
  429. --- a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc
  430. +++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc
  431. @@ -17,12 +17,12 @@ limitations under the License.
  432. #include <glog/logging.h>
  433. #include "absl/container/node_hash_map.h" // from @com_google_absl
  434. -#include "absl/memory/memory.h" // from @com_google_absl
  435. -#include "absl/strings/match.h" // from @com_google_absl
  436. -#include "absl/strings/numbers.h" // from @com_google_absl
  437. -#include "tflite/public/edgetpu_c.h"
  438. +#include "absl/memory/memory.h" // from @com_google_absl
  439. +#include "absl/strings/match.h" // from @com_google_absl
  440. +#include "absl/strings/numbers.h" // from @com_google_absl
  441. #include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
  442. #include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h"
  443. +#include "tflite/public/edgetpu_c.h"
  444. namespace tflite {
  445. namespace delegates {
  446. @@ -50,12 +50,16 @@ inline std::string ConvertBool(bool from_bool) {
  447. return from_bool ? "True" : "False";
  448. }
  449. -bool MatchDevice(const std::string& device, const std::string& type,
  450. +bool MatchDevice(const std::string& device,
  451. + const std::string& type,
  452. int* index) {
  453. const auto prefix(type + ":");
  454. - if (!absl::StartsWith(device, prefix)) return false;
  455. - if (!absl::SimpleAtoi(device.substr(prefix.size()), index)) return false;
  456. - if (*index < 0) return false;
  457. + if (!absl::StartsWith(device, prefix))
  458. + return false;
  459. + if (!absl::SimpleAtoi(device.substr(prefix.size()), index))
  460. + return false;
  461. + if (*index < 0)
  462. + return false;
  463. return true;
  464. }
  465. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc
  466. index a02635b9f3578..6ac4e5c734567 100644
  467. --- a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc
  468. +++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc
  469. @@ -43,7 +43,8 @@ using ::tflite::task::vision::ImageDataFree;
  470. using EdgeTpuCoralPluginTest = testing::TestWithParam<std::string>;
  471. -INSTANTIATE_TEST_SUITE_P(CoralPluginTests, EdgeTpuCoralPluginTest,
  472. +INSTANTIATE_TEST_SUITE_P(CoralPluginTests,
  473. + EdgeTpuCoralPluginTest,
  474. testing::Values(kRegularModelFilePath,
  475. kEdgeTpuModelFilePath));
  476. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc
  477. index 2a182bbd6535a..f0974ed26b826 100644
  478. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc
  479. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc
  480. @@ -17,7 +17,7 @@ limitations under the License.
  481. #include <cstdlib>
  482. -void TfLiteSupportErrorDelete(TfLiteSupportError *error) {
  483. +void TfLiteSupportErrorDelete(TfLiteSupportError* error) {
  484. // `strdup` obtains memory using `malloc` and the memory needs to be
  485. // released using `free`.
  486. free(error->message);
  487. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common.h b/third_party/tflite_support/src/tensorflow_lite_support/c/common.h
  488. index 1e21f1dcb31dc..3ced64226987f 100644
  489. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/common.h
  490. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common.h
  491. @@ -190,10 +190,10 @@ typedef struct TfLiteSupportError {
  492. // Holds the error code.
  493. enum TfLiteSupportErrorCode code;
  494. // Detailed description of the error.
  495. - char *message;
  496. + char* message;
  497. } TfLiteSupportError;
  498. -void TfLiteSupportErrorDelete(TfLiteSupportError *error);
  499. +void TfLiteSupportErrorDelete(TfLiteSupportError* error);
  500. #ifdef __cplusplus
  501. } // extern "C"
  502. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc
  503. index 39287377c4b36..39afb9c8cbdf3 100644
  504. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc
  505. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc
  506. @@ -18,15 +18,17 @@ limitations under the License.
  507. #include <string>
  508. #include "absl/status/status.h" // from @com_google_absl
  509. -#include "absl/strings/cord.h" // from @com_google_absl
  510. +#include "absl/strings/cord.h" // from @com_google_absl
  511. #include "tensorflow_lite_support/cc/common.h"
  512. namespace tflite {
  513. namespace support {
  514. void CreateTfLiteSupportError(enum TfLiteSupportErrorCode code,
  515. - const char* message, TfLiteSupportError** error) {
  516. - if (error == nullptr) return;
  517. + const char* message,
  518. + TfLiteSupportError** error) {
  519. + if (error == nullptr)
  520. + return;
  521. *error = new TfLiteSupportError;
  522. (*error)->code = code;
  523. @@ -35,7 +37,8 @@ void CreateTfLiteSupportError(enum TfLiteSupportErrorCode code,
  524. void CreateTfLiteSupportErrorWithStatus(const absl::Status& status,
  525. TfLiteSupportError** error) {
  526. - if (status.ok() || error == nullptr) return;
  527. + if (status.ok() || error == nullptr)
  528. + return;
  529. // Payload of absl::Status created by the tflite task library stores an
  530. // appropriate value of the enum TfLiteSupportStatus. The integer value
  531. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h
  532. index 6959029575663..551f64a598970 100644
  533. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h
  534. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h
  535. @@ -27,7 +27,8 @@ namespace support {
  536. // Creates a TfLiteSupportError with a TfLiteSupportErrorCode and message.
  537. void CreateTfLiteSupportError(enum TfLiteSupportErrorCode code,
  538. - const char* message, TfLiteSupportError** error);
  539. + const char* message,
  540. + TfLiteSupportError** error);
  541. // Creates a TfLiteSupportError from absl::Status and passes it back as a
  542. // parameter which is a pointer to the error pointer.
  543. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.cc
  544. index 89fba26b9b72f..3f1781a0a7db8 100644
  545. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.cc
  546. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.cc
  547. @@ -109,7 +109,8 @@ TfLiteAudioClassifierOptions TfLiteAudioClassifierOptionsCreate(void) {
  548. }
  549. TfLiteAudioClassifier* TfLiteAudioClassifierFromOptions(
  550. - const TfLiteAudioClassifierOptions* options, TfLiteSupportError** error) {
  551. + const TfLiteAudioClassifierOptions* options,
  552. + TfLiteSupportError** error) {
  553. StatusOr<AudioClassifierOptionsCpp> cpp_option_status =
  554. CreateAudioClassifierCppOptionsFromCOptions(options);
  555. @@ -181,7 +182,8 @@ TfLiteClassificationResult* GetClassificationResultCStruct(
  556. TfLiteClassificationResult* TfLiteAudioClassifierClassify(
  557. const TfLiteAudioClassifier* classifier,
  558. - const TfLiteAudioBuffer* audio_buffer, TfLiteSupportError** error) {
  559. + const TfLiteAudioBuffer* audio_buffer,
  560. + TfLiteSupportError** error) {
  561. if (classifier == nullptr) {
  562. tflite::support::CreateTfLiteSupportError(
  563. kInvalidArgumentError, "Expected non null audio classifier.", error);
  564. @@ -211,7 +213,8 @@ TfLiteClassificationResult* TfLiteAudioClassifierClassify(
  565. }
  566. int TfLiteAudioClassifierGetRequiredInputBufferSize(
  567. - TfLiteAudioClassifier* classifier, TfLiteSupportError** error) {
  568. + TfLiteAudioClassifier* classifier,
  569. + TfLiteSupportError** error) {
  570. if (classifier == nullptr) {
  571. tflite::support::CreateTfLiteSupportError(
  572. kInvalidArgumentError, "Expected non null audio classifier.", error);
  573. @@ -226,7 +229,8 @@ void TfLiteAudioClassifierDelete(TfLiteAudioClassifier* classifier) {
  574. }
  575. TfLiteAudioFormat* TfLiteAudioClassifierGetRequiredAudioFormat(
  576. - TfLiteAudioClassifier* classifier, TfLiteSupportError** error) {
  577. + TfLiteAudioClassifier* classifier,
  578. + TfLiteSupportError** error) {
  579. if (classifier == nullptr) {
  580. tflite::support::CreateTfLiteSupportError(
  581. kInvalidArgumentError, "Expected non null audio classifier.", error);
  582. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.h
  583. index e83295963378c..6af9b27944744 100644
  584. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.h
  585. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.h
  586. @@ -157,7 +157,8 @@ TfLiteAudioClassifierOptions TfLiteAudioClassifierOptionsCreate(void);
  587. // TfLiteSupportErrorDelete(error)
  588. //
  589. TfLiteAudioClassifier* TfLiteAudioClassifierFromOptions(
  590. - const TfLiteAudioClassifierOptions* options, TfLiteSupportError** error);
  591. + const TfLiteAudioClassifierOptions* options,
  592. + TfLiteSupportError** error);
  593. // Invokes the encapsulated TFLite model and classifies the frame_buffer.
  594. // Returns a pointer to the created classification result in case of success or
  595. @@ -185,15 +186,18 @@ TfLiteAudioClassifier* TfLiteAudioClassifierFromOptions(
  596. //
  597. TfLiteClassificationResult* TfLiteAudioClassifierClassify(
  598. const TfLiteAudioClassifier* classifier,
  599. - const TfLiteAudioBuffer* audio_buffer, TfLiteSupportError** error);
  600. + const TfLiteAudioBuffer* audio_buffer,
  601. + TfLiteSupportError** error);
  602. // Returns the input buffer size required by the audio classifier.
  603. int TfLiteAudioClassifierGetRequiredInputBufferSize(
  604. - TfLiteAudioClassifier* classifier, TfLiteSupportError** error);
  605. + TfLiteAudioClassifier* classifier,
  606. + TfLiteSupportError** error);
  607. // Returns the audio format required by the audio classifier.
  608. TfLiteAudioFormat* TfLiteAudioClassifierGetRequiredAudioFormat(
  609. - TfLiteAudioClassifier* classifier, TfLiteSupportError** error);
  610. + TfLiteAudioClassifier* classifier,
  611. + TfLiteSupportError** error);
  612. // Disposes off the audio classifier.
  613. void TfLiteAudioClassifierDelete(TfLiteAudioClassifier* classifier);
  614. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.h
  615. index 2ec7571036d29..471f02fdf2132 100644
  616. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.h
  617. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.h
  618. @@ -45,11 +45,11 @@ typedef struct TfLiteAudioBuffer {
  619. int size;
  620. } TfLiteAudioBuffer;
  621. -void TfLiteAudioBufferDelete(TfLiteAudioBuffer *buffer);
  622. +void TfLiteAudioBufferDelete(TfLiteAudioBuffer* buffer);
  623. void TfLiteAudioBufferDeleteData(const TfLiteAudioBuffer audio_buffer);
  624. -void TfLiteAudioFormatDelete(TfLiteAudioFormat *format);
  625. +void TfLiteAudioFormatDelete(TfLiteAudioFormat* format);
  626. #ifdef __cplusplus
  627. } // extern "C"
  628. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc
  629. index 646e2c237c2f8..b7d7fab827694 100644
  630. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc
  631. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc
  632. @@ -27,7 +27,7 @@ void TfLiteClassificationResultDelete(
  633. for (int head = 0; head < classification_result->size; ++head) {
  634. TfLiteClassifications classifications =
  635. classification_result->classifications[head];
  636. - free(classifications.head_name);
  637. + free(classifications.head_name);
  638. for (int rank = 0; rank < classifications.size; ++rank) {
  639. TfLiteCategoryDelete(&(classifications.categories[rank]));
  640. }
  641. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc
  642. index 26888a832fc34..52907f4fe7d35 100644
  643. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc
  644. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc
  645. @@ -40,7 +40,8 @@ struct TfLiteBertNLClassifier {
  646. };
  647. TfLiteBertNLClassifier* TfLiteBertNLClassifierCreateFromOptions(
  648. - const char* model_path, const TfLiteBertNLClassifierOptions* options) {
  649. + const char* model_path,
  650. + const TfLiteBertNLClassifierOptions* options) {
  651. BertNLClassifierOptionsCpp cc_options;
  652. cc_options.mutable_base_options()->mutable_model_file()->set_file_name(
  653. @@ -64,7 +65,8 @@ TfLiteBertNLClassifier* TfLiteBertNLClassifierCreate(const char* model_path) {
  654. }
  655. Categories* TfLiteBertNLClassifierClassify(
  656. - const TfLiteBertNLClassifier* classifier, const char* text) {
  657. + const TfLiteBertNLClassifier* classifier,
  658. + const char* text) {
  659. std::vector<CategoryCpp> results =
  660. classifier->impl->Classify(absl::string_view(text).data());
  661. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h
  662. index 430f5735c6bd2..94138a291233b 100644
  663. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h
  664. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h
  665. @@ -48,7 +48,8 @@ typedef struct TfLiteBertNLClassifierOptions {
  666. // Creates TfLiteBertNLClassifier from model path and options, returns nullptr
  667. // if the file doesn't exist or is not a well formatted TFLite model path.
  668. TfLiteBertNLClassifier* TfLiteBertNLClassifierCreateFromOptions(
  669. - const char* model_path, const TfLiteBertNLClassifierOptions* options);
  670. + const char* model_path,
  671. + const TfLiteBertNLClassifierOptions* options);
  672. // Creates TfLiteBertNLClassifier from model path and default options, returns
  673. // nullptr if the file doesn't exist or is not a well formatted TFLite model
  674. @@ -57,7 +58,8 @@ TfLiteBertNLClassifier* TfLiteBertNLClassifierCreate(const char* model_path);
  675. // Invokes the encapsulated TFLite model and classifies the input text.
  676. Categories* TfLiteBertNLClassifierClassify(
  677. - const TfLiteBertNLClassifier* classifier, const char* text);
  678. + const TfLiteBertNLClassifier* classifier,
  679. + const char* text);
  680. void TfLiteBertNLClassifierDelete(TfLiteBertNLClassifier* classifier);
  681. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc
  682. index d0d1639357348..1887d5234d180 100644
  683. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc
  684. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc
  685. @@ -48,7 +48,8 @@ TfLiteBertQuestionAnswerer* TfLiteBertQuestionAnswererCreate(
  686. }
  687. TfLiteQaAnswers* TfLiteBertQuestionAnswererAnswer(
  688. - const TfLiteBertQuestionAnswerer* question_answerer, const char* context,
  689. + const TfLiteBertQuestionAnswerer* question_answerer,
  690. + const char* context,
  691. const char* question) {
  692. std::vector<QaAnswerCpp> answers = question_answerer->impl->Answer(
  693. absl::string_view(context).data(), absl::string_view(question).data());
  694. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h
  695. index 7bc6e6ed385db..e9a1190356914 100644
  696. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h
  697. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h
  698. @@ -58,7 +58,8 @@ TfLiteBertQuestionAnswerer* TfLiteBertQuestionAnswererCreate(
  699. // Invokes the encapsulated TFLite model and answers a question based on
  700. // context.
  701. TfLiteQaAnswers* TfLiteBertQuestionAnswererAnswer(
  702. - const TfLiteBertQuestionAnswerer* question_answerer, const char* context,
  703. + const TfLiteBertQuestionAnswerer* question_answerer,
  704. + const char* context,
  705. const char* question);
  706. void TfLiteBertQuestionAnswererDelete(
  707. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc
  708. index d6d86f67a620a..1e6805c1d1cd6 100644
  709. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc
  710. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc
  711. @@ -37,7 +37,8 @@ struct TfLiteNLClassifier {
  712. };
  713. TfLiteNLClassifier* TfLiteNLClassifierCreateFromOptions(
  714. - const char* model_path, const TfLiteNLClassifierOptions* options) {
  715. + const char* model_path,
  716. + const TfLiteNLClassifierOptions* options) {
  717. auto classifier_status = NLClassifierCpp::CreateFromFileAndOptions(
  718. std::string(model_path),
  719. {
  720. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h
  721. index c47dd59b13eb4..389ca5d686df0 100644
  722. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h
  723. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h
  724. @@ -48,7 +48,8 @@ typedef struct TfLiteNLClassifierOptions {
  725. // Creates TfLiteNLClassifier from model path and options, returns nullptr if
  726. // the file doesn't exist or is not a well formatted TFLite model path.
  727. TfLiteNLClassifier* TfLiteNLClassifierCreateFromOptions(
  728. - const char* model_path, const TfLiteNLClassifierOptions* options);
  729. + const char* model_path,
  730. + const TfLiteNLClassifierOptions* options);
  731. // Invokes the encapsulated TFLite model and classifies the input text.
  732. Categories* TfLiteNLClassifierClassify(const TfLiteNLClassifier* classifier,
  733. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc
  734. index 52e215116b51e..183468a6855aa 100644
  735. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc
  736. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc
  737. @@ -110,7 +110,8 @@ TfLiteImageClassifierOptions TfLiteImageClassifierOptionsCreate(void) {
  738. }
  739. TfLiteImageClassifier* TfLiteImageClassifierFromOptions(
  740. - const TfLiteImageClassifierOptions* options, TfLiteSupportError** error) {
  741. + const TfLiteImageClassifierOptions* options,
  742. + TfLiteSupportError** error) {
  743. StatusOr<ImageClassifierOptionsCpp> cpp_option_status =
  744. CreateImageClassifierCppOptionsFromCOptions(options);
  745. @@ -178,7 +179,8 @@ TfLiteClassificationResult* GetClassificationResultCStruct(
  746. TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi(
  747. const TfLiteImageClassifier* classifier,
  748. - const TfLiteFrameBuffer* frame_buffer, const TfLiteBoundingBox* roi,
  749. + const TfLiteFrameBuffer* frame_buffer,
  750. + const TfLiteBoundingBox* roi,
  751. TfLiteSupportError** error) {
  752. if (classifier == nullptr) {
  753. tflite::support::CreateTfLiteSupportError(
  754. @@ -221,7 +223,8 @@ TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi(
  755. TfLiteClassificationResult* TfLiteImageClassifierClassify(
  756. const TfLiteImageClassifier* classifier,
  757. - const TfLiteFrameBuffer* frame_buffer, TfLiteSupportError** error) {
  758. + const TfLiteFrameBuffer* frame_buffer,
  759. + TfLiteSupportError** error) {
  760. return TfLiteImageClassifierClassifyWithRoi(classifier, frame_buffer, nullptr,
  761. error);
  762. }
  763. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h
  764. index dca83e00f9455..837c9894a2302 100644
  765. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h
  766. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h
  767. @@ -158,7 +158,8 @@ TfLiteImageClassifierOptions TfLiteImageClassifierOptionsCreate(void);
  768. // TfLiteSupportErrorDelete(error)
  769. //
  770. TfLiteImageClassifier* TfLiteImageClassifierFromOptions(
  771. - const TfLiteImageClassifierOptions* options, TfLiteSupportError** error);
  772. + const TfLiteImageClassifierOptions* options,
  773. + TfLiteSupportError** error);
  774. // Invokes the encapsulated TFLite model and classifies the frame_buffer.
  775. // Returns a pointer to the created classification result in case of success or
  776. @@ -186,7 +187,8 @@ TfLiteImageClassifier* TfLiteImageClassifierFromOptions(
  777. //
  778. TfLiteClassificationResult* TfLiteImageClassifierClassify(
  779. const TfLiteImageClassifier* classifier,
  780. - const TfLiteFrameBuffer* frame_buffer, TfLiteSupportError** error);
  781. + const TfLiteFrameBuffer* frame_buffer,
  782. + TfLiteSupportError** error);
  783. // Invokes the encapsulated TFLite model and classifies the region of the
  784. // frame_buffer specified by the bounding box. Same as TfLiteImageClassifier*
  785. @@ -198,7 +200,8 @@ TfLiteClassificationResult* TfLiteImageClassifierClassify(
  786. // operations.
  787. TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi(
  788. const TfLiteImageClassifier* classifier,
  789. - const TfLiteFrameBuffer* frame_buffer, const TfLiteBoundingBox* roi,
  790. + const TfLiteFrameBuffer* frame_buffer,
  791. + const TfLiteBoundingBox* roi,
  792. TfLiteSupportError** error);
  793. // Disposes off the image classifier.
  794. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_segmenter.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_segmenter.cc
  795. index e7395ddbde80e..d2cf362e82ed7 100644
  796. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_segmenter.cc
  797. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_segmenter.cc
  798. @@ -92,7 +92,8 @@ TfLiteImageSegmenterOptions TfLiteImageSegmenterOptionsCreate(void) {
  799. }
  800. TfLiteImageSegmenter* TfLiteImageSegmenterFromOptions(
  801. - const TfLiteImageSegmenterOptions* options, TfLiteSupportError** error) {
  802. + const TfLiteImageSegmenterOptions* options,
  803. + TfLiteSupportError** error) {
  804. StatusOr<ImageSegmenterOptionsCpp> cpp_option_status =
  805. CreateImageSegmenterCppOptionsFromCOptions(options);
  806. @@ -182,7 +183,8 @@ TfLiteSegmentationResult* GetSegmentationResultCStruct(
  807. TfLiteSegmentationResult* TfLiteImageSegmenterSegment(
  808. const TfLiteImageSegmenter* segmenter,
  809. - const TfLiteFrameBuffer* frame_buffer, TfLiteSupportError** error) {
  810. + const TfLiteFrameBuffer* frame_buffer,
  811. + TfLiteSupportError** error) {
  812. if (segmenter == nullptr) {
  813. tflite::support::CreateTfLiteSupportError(
  814. kInvalidArgumentError, "Expected non null image segmenter.", error);
  815. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_segmenter.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_segmenter.h
  816. index c2964fad2c144..e0dc62e224b99 100644
  817. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_segmenter.h
  818. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_segmenter.h
  819. @@ -172,7 +172,8 @@ TfLiteImageSegmenterOptions TfLiteImageSegmenterOptionsCreate(void);
  820. // TfLiteSupportErrorDelete(error)
  821. //
  822. TfLiteImageSegmenter* TfLiteImageSegmenterFromOptions(
  823. - const TfLiteImageSegmenterOptions* options, TfLiteSupportError** error);
  824. + const TfLiteImageSegmenterOptions* options,
  825. + TfLiteSupportError** error);
  826. // Invokes the encapsulated TFLite model and performs image segmentation on
  827. // the frame_buffer.
  828. @@ -201,7 +202,8 @@ TfLiteImageSegmenter* TfLiteImageSegmenterFromOptions(
  829. //
  830. TfLiteSegmentationResult* TfLiteImageSegmenterSegment(
  831. const TfLiteImageSegmenter* segmenter,
  832. - const TfLiteFrameBuffer* frame_buffer, TfLiteSupportError** error);
  833. + const TfLiteFrameBuffer* frame_buffer,
  834. + TfLiteSupportError** error);
  835. // Disposes of the image segmenter.
  836. void TfLiteImageSegmenterDelete(TfLiteImageSegmenter* segmenter);
  837. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.cc
  838. index 1389a2de0ee75..92535e863b9a3 100644
  839. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.cc
  840. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.cc
  841. @@ -109,7 +109,8 @@ TfLiteObjectDetectorOptions TfLiteObjectDetectorOptionsCreate(void) {
  842. }
  843. TfLiteObjectDetector* TfLiteObjectDetectorFromOptions(
  844. - const TfLiteObjectDetectorOptions* options, TfLiteSupportError** error) {
  845. + const TfLiteObjectDetectorOptions* options,
  846. + TfLiteSupportError** error) {
  847. StatusOr<ObjectDetectorOptionsCpp> cpp_option_status =
  848. CreateObjectDetectorCppOptionsFromCOptions(options);
  849. @@ -174,7 +175,8 @@ TfLiteDetectionResult* GetDetectionResultCStruct(
  850. }
  851. TfLiteDetectionResult* TfLiteObjectDetectorDetect(
  852. - const TfLiteObjectDetector* detector, const TfLiteFrameBuffer* frame_buffer,
  853. + const TfLiteObjectDetector* detector,
  854. + const TfLiteFrameBuffer* frame_buffer,
  855. TfLiteSupportError** error) {
  856. if (detector == nullptr) {
  857. tflite::support::CreateTfLiteSupportError(
  858. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h
  859. index e2e08ec161559..b4d4564fefeb0 100644
  860. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h
  861. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h
  862. @@ -157,7 +157,8 @@ TfLiteObjectDetectorOptions TfLiteObjectDetectorOptionsCreate(void);
  863. // TfLiteSupportErrorDelete(error)
  864. //
  865. TfLiteObjectDetector* TfLiteObjectDetectorFromOptions(
  866. - const TfLiteObjectDetectorOptions* options, TfLiteSupportError** error);
  867. + const TfLiteObjectDetectorOptions* options,
  868. + TfLiteSupportError** error);
  869. // Invokes the encapsulated TFLite model and performs object detection on the
  870. // frame_buffer. Returns a pointer to the created object detection result result
  871. @@ -185,7 +186,8 @@ TfLiteObjectDetector* TfLiteObjectDetectorFromOptions(
  872. // TfLiteSupportErrorDelete(error)
  873. //
  874. TfLiteDetectionResult* TfLiteObjectDetectorDetect(
  875. - const TfLiteObjectDetector* detector, const TfLiteFrameBuffer* frame_buffer,
  876. + const TfLiteObjectDetector* detector,
  877. + const TfLiteFrameBuffer* frame_buffer,
  878. TfLiteSupportError** error);
  879. // Disposes off the object detector.
  880. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/audio_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/audio_classifier_test.cc
  881. index 17b2a4ccede29..126784cf6c755 100644
  882. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/audio_classifier_test.cc
  883. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/audio_classifier_test.cc
  884. @@ -45,9 +45,10 @@ constexpr char kYamNetAudioClassifierWithMetadata[] =
  885. "yamnet_audio_classifier_with_metadata.tflite";
  886. StatusOr<TfLiteAudioBuffer> LoadAudioBufferFromFileNamed(
  887. - const std::string wav_file, int buffer_size) {
  888. - std::string contents = ReadFile(
  889. - JoinPath("./" /*test src dir*/, kTestDataDirectory, wav_file));
  890. + const std::string wav_file,
  891. + int buffer_size) {
  892. + std::string contents =
  893. + ReadFile(JoinPath("./" /*test src dir*/, kTestDataDirectory, wav_file));
  894. uint32_t decoded_sample_count;
  895. uint16_t decoded_channel_count;
  896. @@ -90,7 +91,8 @@ void Verify(TfLiteClassificationResult* classification_result,
  897. }
  898. void Verify(TfLiteClassifications& classifications,
  899. - int expected_categories_size, int expected_head_index,
  900. + int expected_categories_size,
  901. + int expected_head_index,
  902. char const* expected_head_name) {
  903. EXPECT_EQ(classifications.size, expected_categories_size);
  904. EXPECT_EQ(classifications.head_index, expected_head_index);
  905. @@ -101,8 +103,10 @@ void Verify(TfLiteClassifications& classifications,
  906. EXPECT_NE(classifications.categories, nullptr);
  907. }
  908. -void Verify(TfLiteCategory& category, int expected_index,
  909. - char const* expected_label, float expected_score) {
  910. +void Verify(TfLiteCategory& category,
  911. + int expected_index,
  912. + char const* expected_label,
  913. + float expected_score) {
  914. const float kPrecision = 1e-6;
  915. EXPECT_EQ(category.index, expected_index);
  916. EXPECT_NE(category.label, nullptr);
  917. @@ -115,7 +119,8 @@ void Verify(TfLiteCategory& category, int expected_index,
  918. EXPECT_NEAR(category.score, expected_score, kPrecision);
  919. }
  920. -void Verify(TfLiteSupportError* error, TfLiteSupportErrorCode error_code,
  921. +void Verify(TfLiteSupportError* error,
  922. + TfLiteSupportErrorCode error_code,
  923. char const* message) {
  924. ASSERT_NE(error, nullptr);
  925. EXPECT_EQ(error->code, kInvalidArgumentError);
  926. @@ -133,7 +138,8 @@ TEST_F(AudioClassifierFromOptionsTest, FailsWithMissingModelPathAndError) {
  927. TfLiteAudioClassifierFromOptions(&options, &error);
  928. EXPECT_EQ(audio_classifier, nullptr);
  929. - if (audio_classifier) TfLiteAudioClassifierDelete(audio_classifier);
  930. + if (audio_classifier)
  931. + TfLiteAudioClassifierDelete(audio_classifier);
  932. Verify(error, kInvalidArgumentError,
  933. "INVALID_ARGUMENT: Missing mandatory `model_file` field in "
  934. @@ -143,9 +149,8 @@ TEST_F(AudioClassifierFromOptionsTest, FailsWithMissingModelPathAndError) {
  935. }
  936. TEST_F(AudioClassifierFromOptionsTest, SucceedsWithModelPath) {
  937. - std::string model_path =
  938. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  939. - kYamNetAudioClassifierWithMetadata);
  940. + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
  941. + kYamNetAudioClassifierWithMetadata);
  942. TfLiteAudioClassifierOptions options = TfLiteAudioClassifierOptionsCreate();
  943. options.base_options.model_file.file_path = model_path.data();
  944. TfLiteAudioClassifier* audio_classifier =
  945. @@ -158,9 +163,8 @@ TEST_F(AudioClassifierFromOptionsTest, SucceedsWithModelPath) {
  946. class AudioClassifierClassifyTest : public tflite_shims::testing::Test {
  947. protected:
  948. void SetUp() override {
  949. - std::string model_path =
  950. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  951. - kYamNetAudioClassifierWithMetadata);
  952. + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
  953. + kYamNetAudioClassifierWithMetadata);
  954. TfLiteAudioClassifierOptions options = TfLiteAudioClassifierOptionsCreate();
  955. options.base_options.model_file.file_path = model_path.data();
  956. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc
  957. index 0a59344f4394c..cce2fa63fad17 100644
  958. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc
  959. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc
  960. @@ -44,8 +44,8 @@ constexpr char kMobileNetQuantizedWithMetadata[] =
  961. "mobilenet_v1_0.25_224_quant.tflite";
  962. StatusOr<ImageData> LoadImage(const char* image_name) {
  963. - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
  964. - kTestDataDirectory, image_name));
  965. + return DecodeImageFromFile(
  966. + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
  967. }
  968. class ImageClassifierFromOptionsTest : public tflite_shims::testing::Test {};
  969. @@ -56,7 +56,8 @@ TEST_F(ImageClassifierFromOptionsTest, FailsWithNullOptionsAndError) {
  970. TfLiteImageClassifierFromOptions(nullptr, &error);
  971. EXPECT_EQ(image_classifier, nullptr);
  972. - if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
  973. + if (image_classifier)
  974. + TfLiteImageClassifierDelete(image_classifier);
  975. ASSERT_NE(error, nullptr);
  976. EXPECT_EQ(error->code, kInvalidArgumentError);
  977. @@ -71,7 +72,8 @@ TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPath) {
  978. TfLiteImageClassifier* image_classifier =
  979. TfLiteImageClassifierFromOptions(&options, nullptr);
  980. EXPECT_EQ(image_classifier, nullptr);
  981. - if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
  982. + if (image_classifier)
  983. + TfLiteImageClassifierDelete(image_classifier);
  984. }
  985. TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPathAndError) {
  986. @@ -82,7 +84,8 @@ TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPathAndError) {
  987. TfLiteImageClassifierFromOptions(&options, &error);
  988. EXPECT_EQ(image_classifier, nullptr);
  989. - if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
  990. + if (image_classifier)
  991. + TfLiteImageClassifierDelete(image_classifier);
  992. ASSERT_NE(error, nullptr);
  993. EXPECT_EQ(error->code, kInvalidArgumentError);
  994. @@ -93,9 +96,8 @@ TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPathAndError) {
  995. }
  996. TEST_F(ImageClassifierFromOptionsTest, SucceedsWithModelPath) {
  997. - std::string model_path =
  998. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  999. - kMobileNetQuantizedWithMetadata);
  1000. + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1001. + kMobileNetQuantizedWithMetadata);
  1002. TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
  1003. options.base_options.model_file.file_path = model_path.data();
  1004. TfLiteImageClassifier* image_classifier =
  1005. @@ -106,9 +108,8 @@ TEST_F(ImageClassifierFromOptionsTest, SucceedsWithModelPath) {
  1006. }
  1007. TEST_F(ImageClassifierFromOptionsTest, SucceedsWithNumberOfThreadsAndError) {
  1008. - std::string model_path =
  1009. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1010. - kMobileNetQuantizedWithMetadata);
  1011. + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1012. + kMobileNetQuantizedWithMetadata);
  1013. TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
  1014. options.base_options.model_file.file_path = model_path.data();
  1015. options.base_options.compute_settings.cpu_settings.num_threads = 3;
  1016. @@ -120,15 +121,16 @@ TEST_F(ImageClassifierFromOptionsTest, SucceedsWithNumberOfThreadsAndError) {
  1017. EXPECT_NE(image_classifier, nullptr);
  1018. EXPECT_EQ(error, nullptr);
  1019. - if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
  1020. - if (error) TfLiteSupportErrorDelete(error);
  1021. + if (image_classifier)
  1022. + TfLiteImageClassifierDelete(image_classifier);
  1023. + if (error)
  1024. + TfLiteSupportErrorDelete(error);
  1025. }
  1026. TEST_F(ImageClassifierFromOptionsTest,
  1027. FailsWithClassNameDenyListAndClassNameAllowListAndError) {
  1028. - std::string model_path =
  1029. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1030. - kMobileNetQuantizedWithMetadata);
  1031. + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1032. + kMobileNetQuantizedWithMetadata);
  1033. TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
  1034. options.base_options.model_file.file_path = model_path.data();
  1035. @@ -146,7 +148,8 @@ TEST_F(ImageClassifierFromOptionsTest,
  1036. TfLiteImageClassifierFromOptions(&options, &error);
  1037. EXPECT_EQ(image_classifier, nullptr);
  1038. - if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
  1039. + if (image_classifier)
  1040. + TfLiteImageClassifierDelete(image_classifier);
  1041. ASSERT_NE(error, nullptr);
  1042. EXPECT_EQ(error->code, kInvalidArgumentError);
  1043. @@ -158,7 +161,8 @@ TEST_F(ImageClassifierFromOptionsTest,
  1044. TEST(ImageClassifierNullClassifierClassifyTest,
  1045. FailsWithNullImageClassifierAndError) {
  1046. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
  1047. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1048. + LoadImage("burger-224.png"));
  1049. TfLiteSupportError* error = nullptr;
  1050. TfLiteClassificationResult* classification_result =
  1051. @@ -181,9 +185,8 @@ TEST(ImageClassifierNullClassifierClassifyTest,
  1052. class ImageClassifierClassifyTest : public tflite_shims::testing::Test {
  1053. protected:
  1054. void SetUp() override {
  1055. - std::string model_path =
  1056. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1057. - kMobileNetQuantizedWithMetadata);
  1058. + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1059. + kMobileNetQuantizedWithMetadata);
  1060. TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
  1061. options.base_options.model_file.file_path = model_path.data();
  1062. @@ -196,7 +199,8 @@ class ImageClassifierClassifyTest : public tflite_shims::testing::Test {
  1063. };
  1064. TEST_F(ImageClassifierClassifyTest, SucceedsWithImageData) {
  1065. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
  1066. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1067. + LoadImage("burger-224.png"));
  1068. TfLiteFrameBuffer frame_buffer = {
  1069. .format = kRGB,
  1070. @@ -223,7 +227,8 @@ TEST_F(ImageClassifierClassifyTest, SucceedsWithImageData) {
  1071. }
  1072. TEST_F(ImageClassifierClassifyTest, FailsWithNullFrameBufferAndError) {
  1073. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
  1074. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1075. + LoadImage("burger-224.png"));
  1076. TfLiteSupportError* error = nullptr;
  1077. TfLiteClassificationResult* classification_result =
  1078. @@ -244,7 +249,8 @@ TEST_F(ImageClassifierClassifyTest, FailsWithNullFrameBufferAndError) {
  1079. }
  1080. TEST_F(ImageClassifierClassifyTest, FailsWithNullImageDataAndError) {
  1081. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
  1082. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1083. + LoadImage("burger-224.png"));
  1084. TfLiteFrameBuffer frame_buffer = {.format = kRGB, .orientation = kTopLeft};
  1085. @@ -267,7 +273,8 @@ TEST_F(ImageClassifierClassifyTest, FailsWithNullImageDataAndError) {
  1086. }
  1087. TEST_F(ImageClassifierClassifyTest, SucceedsWithRoiWithinImageBounds) {
  1088. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
  1089. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1090. + LoadImage("burger-224.png"));
  1091. TfLiteFrameBuffer frame_buffer = {
  1092. .format = kRGB,
  1093. @@ -298,7 +305,8 @@ TEST_F(ImageClassifierClassifyTest, SucceedsWithRoiWithinImageBounds) {
  1094. }
  1095. TEST_F(ImageClassifierClassifyTest, FailsWithRoiOutsideImageBoundsAndError) {
  1096. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
  1097. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1098. + LoadImage("burger-224.png"));
  1099. TfLiteFrameBuffer frame_buffer = {
  1100. .format = kRGB,
  1101. @@ -330,9 +338,8 @@ TEST_F(ImageClassifierClassifyTest, FailsWithRoiOutsideImageBoundsAndError) {
  1102. TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
  1103. SucceedsWithClassNameDenyList) {
  1104. char* denylisted_label_name = (char*)"cheeseburger";
  1105. - std::string model_path =
  1106. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1107. - kMobileNetQuantizedWithMetadata);
  1108. + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1109. + kMobileNetQuantizedWithMetadata);
  1110. TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
  1111. options.base_options.model_file.file_path = model_path.data();
  1112. @@ -345,7 +352,8 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
  1113. TfLiteImageClassifierFromOptions(&options, nullptr);
  1114. ASSERT_NE(image_classifier, nullptr);
  1115. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
  1116. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1117. + LoadImage("burger-224.png"));
  1118. TfLiteFrameBuffer frame_buffer = {
  1119. .format = kRGB,
  1120. @@ -357,7 +365,8 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
  1121. TfLiteImageClassifierClassify(image_classifier, &frame_buffer, nullptr);
  1122. ImageDataFree(&image_data);
  1123. - if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
  1124. + if (image_classifier)
  1125. + TfLiteImageClassifierDelete(image_classifier);
  1126. ASSERT_NE(classification_result, nullptr);
  1127. EXPECT_GE(classification_result->size, 1);
  1128. @@ -374,10 +383,9 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
  1129. TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
  1130. SucceedsWithClassNameAllowList) {
  1131. char* allowlisted_label_name = (char*)"cheeseburger";
  1132. - std::string model_path =
  1133. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1134. - kMobileNetQuantizedWithMetadata)
  1135. - .data();
  1136. + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1137. + kMobileNetQuantizedWithMetadata)
  1138. + .data();
  1139. TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
  1140. options.base_options.model_file.file_path = model_path.data();
  1141. @@ -390,7 +398,8 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
  1142. TfLiteImageClassifierFromOptions(&options, nullptr);
  1143. ASSERT_NE(image_classifier, nullptr);
  1144. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
  1145. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1146. + LoadImage("burger-224.png"));
  1147. TfLiteFrameBuffer frame_buffer = {
  1148. .format = kRGB,
  1149. @@ -402,7 +411,8 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
  1150. TfLiteImageClassifierClassify(image_classifier, &frame_buffer, nullptr);
  1151. ImageDataFree(&image_data);
  1152. - if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
  1153. + if (image_classifier)
  1154. + TfLiteImageClassifierDelete(image_classifier);
  1155. ASSERT_NE(classification_result, nullptr);
  1156. EXPECT_GE(classification_result->size, 1);
  1157. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_segmenter_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_segmenter_test.cc
  1158. index d4c8106b2729d..c03c15d6fe6b7 100644
  1159. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_segmenter_test.cc
  1160. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_segmenter_test.cc
  1161. @@ -46,8 +46,8 @@ constexpr char kTestDataDirectory[] =
  1162. constexpr char kDeepLabV3[] = "deeplabv3.tflite";
  1163. StatusOr<ImageData> LoadImage(const char* image_name) {
  1164. - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
  1165. - kTestDataDirectory, image_name));
  1166. + return DecodeImageFromFile(
  1167. + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
  1168. }
  1169. // The maximum fraction of pixels in the candidate mask that can have a
  1170. @@ -59,8 +59,11 @@ constexpr float kGoldenMaskTolerance = 1e-2;
  1171. // 20 means class index 2, etc.
  1172. constexpr int kGoldenMaskMagnificationFactor = 10;
  1173. -void InitializeColoredLabel(TfLiteColoredLabel& colored_label, uint32_t r,
  1174. - uint32_t g, uint32_t b, const char* label) {
  1175. +void InitializeColoredLabel(TfLiteColoredLabel& colored_label,
  1176. + uint32_t r,
  1177. + uint32_t g,
  1178. + uint32_t b,
  1179. + const char* label) {
  1180. colored_label.r = r;
  1181. colored_label.g = g;
  1182. colored_label.b = b;
  1183. @@ -129,7 +132,8 @@ TEST_F(ImageSegmenterFromOptionsTest, FailsWithNullOptionsAndError) {
  1184. EXPECT_EQ(image_segmenter, nullptr);
  1185. - if (image_segmenter) TfLiteImageSegmenterDelete(image_segmenter);
  1186. + if (image_segmenter)
  1187. + TfLiteImageSegmenterDelete(image_segmenter);
  1188. ASSERT_NE(error, nullptr);
  1189. EXPECT_EQ(error->code, kInvalidArgumentError);
  1190. @@ -147,7 +151,8 @@ TEST_F(ImageSegmenterFromOptionsTest, FailsWithMissingModelPath) {
  1191. EXPECT_EQ(image_segmenter, nullptr);
  1192. - if (image_segmenter) TfLiteImageSegmenterDelete(image_segmenter);
  1193. + if (image_segmenter)
  1194. + TfLiteImageSegmenterDelete(image_segmenter);
  1195. }
  1196. TEST_F(ImageSegmenterFromOptionsTest, FailsWithMissingModelPathAndError) {
  1197. @@ -160,7 +165,8 @@ TEST_F(ImageSegmenterFromOptionsTest, FailsWithMissingModelPathAndError) {
  1198. EXPECT_EQ(image_segmenter, nullptr);
  1199. - if (image_segmenter) TfLiteImageSegmenterDelete(image_segmenter);
  1200. + if (image_segmenter)
  1201. + TfLiteImageSegmenterDelete(image_segmenter);
  1202. ASSERT_NE(error, nullptr);
  1203. EXPECT_EQ(error->code, kInvalidArgumentError);
  1204. @@ -171,8 +177,8 @@ TEST_F(ImageSegmenterFromOptionsTest, FailsWithMissingModelPathAndError) {
  1205. }
  1206. TEST_F(ImageSegmenterFromOptionsTest, SucceedsWithModelPath) {
  1207. - std::string model_path = JoinPath("./" /*test src dir*/,
  1208. - kTestDataDirectory, kDeepLabV3);
  1209. + std::string model_path =
  1210. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3);
  1211. TfLiteImageSegmenterOptions options = TfLiteImageSegmenterOptionsCreate();
  1212. options.base_options.model_file.file_path = model_path.data();
  1213. @@ -186,8 +192,8 @@ TEST_F(ImageSegmenterFromOptionsTest, SucceedsWithModelPath) {
  1214. }
  1215. TEST_F(ImageSegmenterFromOptionsTest, SucceedsWithNumberOfThreadsAndError) {
  1216. - std::string model_path = JoinPath("./" /*test src dir*/,
  1217. - kTestDataDirectory, kDeepLabV3);
  1218. + std::string model_path =
  1219. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3);
  1220. TfLiteImageSegmenterOptions options = TfLiteImageSegmenterOptionsCreate();
  1221. options.base_options.model_file.file_path = model_path.data();
  1222. @@ -200,13 +206,15 @@ TEST_F(ImageSegmenterFromOptionsTest, SucceedsWithNumberOfThreadsAndError) {
  1223. EXPECT_NE(image_segmenter, nullptr);
  1224. EXPECT_EQ(error, nullptr);
  1225. - if (image_segmenter) TfLiteImageSegmenterDelete(image_segmenter);
  1226. - if (error) TfLiteSupportErrorDelete(error);
  1227. + if (image_segmenter)
  1228. + TfLiteImageSegmenterDelete(image_segmenter);
  1229. + if (error)
  1230. + TfLiteSupportErrorDelete(error);
  1231. }
  1232. TEST_F(ImageSegmenterFromOptionsTest, FailsWithUnspecifiedOutputTypeAndError) {
  1233. - std::string model_path = JoinPath("./" /*test src dir*/,
  1234. - kTestDataDirectory, kDeepLabV3);
  1235. + std::string model_path =
  1236. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3);
  1237. TfLiteImageSegmenterOptions options = TfLiteImageSegmenterOptionsCreate();
  1238. options.base_options.model_file.file_path = model_path.data();
  1239. @@ -219,15 +227,17 @@ TEST_F(ImageSegmenterFromOptionsTest, FailsWithUnspecifiedOutputTypeAndError) {
  1240. EXPECT_EQ(image_segmenter, nullptr);
  1241. EXPECT_NE(error, nullptr);
  1242. - if (image_segmenter) TfLiteImageSegmenterDelete(image_segmenter);
  1243. - if (error) TfLiteSupportErrorDelete(error);
  1244. + if (image_segmenter)
  1245. + TfLiteImageSegmenterDelete(image_segmenter);
  1246. + if (error)
  1247. + TfLiteSupportErrorDelete(error);
  1248. }
  1249. class ImageSegmenterSegmentTest : public tflite_shims::testing::Test {
  1250. protected:
  1251. void SetUp() override {
  1252. - std::string model_path = JoinPath("./" /*test src dir*/,
  1253. - kTestDataDirectory, kDeepLabV3);
  1254. + std::string model_path =
  1255. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3);
  1256. TfLiteImageSegmenterOptions options = TfLiteImageSegmenterOptionsCreate();
  1257. options.base_options.model_file.file_path = model_path.data();
  1258. @@ -241,7 +251,7 @@ class ImageSegmenterSegmentTest : public tflite_shims::testing::Test {
  1259. TEST_F(ImageSegmenterSegmentTest, SucceedsWithCategoryMask) {
  1260. SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1261. - LoadImage("segmentation_input_rotation0.jpg"));
  1262. + LoadImage("segmentation_input_rotation0.jpg"));
  1263. TfLiteFrameBuffer frame_buffer = {
  1264. .format = kRGB,
  1265. @@ -264,7 +274,7 @@ TEST_F(ImageSegmenterSegmentTest, SucceedsWithCategoryMask) {
  1266. // Load golden mask output.
  1267. SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask,
  1268. - LoadImage("segmentation_golden_rotation0.png"));
  1269. + LoadImage("segmentation_golden_rotation0.png"));
  1270. int inconsistent_pixels = 0;
  1271. int num_pixels = golden_mask.height * golden_mask.width;
  1272. @@ -285,8 +295,9 @@ TEST_F(ImageSegmenterSegmentTest, SucceedsWithCategoryMask) {
  1273. }
  1274. TEST_F(ImageSegmenterSegmentTest, SucceedsWithCategoryMaskAndOrientation) {
  1275. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1276. - LoadImage("segmentation_input_rotation90_flop.jpg"));
  1277. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  1278. + ImageData image_data,
  1279. + LoadImage("segmentation_input_rotation90_flop.jpg"));
  1280. TfLiteFrameBuffer frame_buffer = {
  1281. .format = kRGB,
  1282. @@ -308,8 +319,9 @@ TEST_F(ImageSegmenterSegmentTest, SucceedsWithCategoryMaskAndOrientation) {
  1283. segmentation_result->segmentations[0]);
  1284. // Load golden mask output.
  1285. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask,
  1286. - LoadImage("segmentation_golden_rotation90_flop.png"));
  1287. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  1288. + ImageData golden_mask,
  1289. + LoadImage("segmentation_golden_rotation90_flop.png"));
  1290. int inconsistent_pixels = 0;
  1291. int num_pixels = golden_mask.height * golden_mask.width;
  1292. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/object_detector_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/object_detector_test.cc
  1293. index 0171e584fdd3d..78d78f5ddb6d1 100644
  1294. --- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/object_detector_test.cc
  1295. +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/object_detector_test.cc
  1296. @@ -46,8 +46,8 @@ constexpr char kMobileSsdWithMetadata[] =
  1297. "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
  1298. StatusOr<ImageData> LoadImage(const char* image_name) {
  1299. - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
  1300. - kTestDataDirectory, image_name));
  1301. + return DecodeImageFromFile(
  1302. + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
  1303. }
  1304. void VerifyDetection(const TfLiteDetection& detection,
  1305. @@ -96,7 +96,8 @@ TEST_F(ObjectDetectorFromOptionsTest, FailsWithNullOptionsAndError) {
  1306. TfLiteObjectDetectorFromOptions(nullptr, &error);
  1307. EXPECT_EQ(object_detector, nullptr);
  1308. - if (object_detector) TfLiteObjectDetectorDelete(object_detector);
  1309. + if (object_detector)
  1310. + TfLiteObjectDetectorDelete(object_detector);
  1311. ASSERT_NE(error, nullptr);
  1312. EXPECT_EQ(error->code, kInvalidArgumentError);
  1313. @@ -111,7 +112,8 @@ TEST_F(ObjectDetectorFromOptionsTest, FailsWithMissingModelPath) {
  1314. TfLiteObjectDetector* object_detector =
  1315. TfLiteObjectDetectorFromOptions(&options, nullptr);
  1316. EXPECT_EQ(object_detector, nullptr);
  1317. - if (object_detector) TfLiteObjectDetectorDelete(object_detector);
  1318. + if (object_detector)
  1319. + TfLiteObjectDetectorDelete(object_detector);
  1320. }
  1321. TEST_F(ObjectDetectorFromOptionsTest, FailsWithMissingModelPathAndError) {
  1322. @@ -122,7 +124,8 @@ TEST_F(ObjectDetectorFromOptionsTest, FailsWithMissingModelPathAndError) {
  1323. TfLiteObjectDetectorFromOptions(&options, &error);
  1324. EXPECT_EQ(object_detector, nullptr);
  1325. - if (object_detector) TfLiteObjectDetectorDelete(object_detector);
  1326. + if (object_detector)
  1327. + TfLiteObjectDetectorDelete(object_detector);
  1328. ASSERT_NE(error, nullptr);
  1329. EXPECT_EQ(error->code, kInvalidArgumentError);
  1330. @@ -133,8 +136,8 @@ TEST_F(ObjectDetectorFromOptionsTest, FailsWithMissingModelPathAndError) {
  1331. }
  1332. TEST_F(ObjectDetectorFromOptionsTest, SucceedsWithModelPath) {
  1333. - std::string model_path = JoinPath("./" /*test src dir*/,
  1334. - kTestDataDirectory, kMobileSsdWithMetadata);
  1335. + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1336. + kMobileSsdWithMetadata);
  1337. TfLiteObjectDetectorOptions options = TfLiteObjectDetectorOptionsCreate();
  1338. options.base_options.model_file.file_path = model_path.data();
  1339. TfLiteObjectDetector* object_detector =
  1340. @@ -145,8 +148,8 @@ TEST_F(ObjectDetectorFromOptionsTest, SucceedsWithModelPath) {
  1341. }
  1342. TEST_F(ObjectDetectorFromOptionsTest, SucceedsWithNumberOfThreadsAndError) {
  1343. - std::string model_path = JoinPath("./" /*test src dir*/,
  1344. - kTestDataDirectory, kMobileSsdWithMetadata);
  1345. + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1346. + kMobileSsdWithMetadata);
  1347. TfLiteObjectDetectorOptions options = TfLiteObjectDetectorOptionsCreate();
  1348. options.base_options.model_file.file_path = model_path.data();
  1349. options.base_options.compute_settings.cpu_settings.num_threads = 3;
  1350. @@ -158,14 +161,16 @@ TEST_F(ObjectDetectorFromOptionsTest, SucceedsWithNumberOfThreadsAndError) {
  1351. EXPECT_NE(object_detector, nullptr);
  1352. EXPECT_EQ(error, nullptr);
  1353. - if (object_detector) TfLiteObjectDetectorDelete(object_detector);
  1354. - if (error) TfLiteSupportErrorDelete(error);
  1355. + if (object_detector)
  1356. + TfLiteObjectDetectorDelete(object_detector);
  1357. + if (error)
  1358. + TfLiteSupportErrorDelete(error);
  1359. }
  1360. TEST_F(ObjectDetectorFromOptionsTest,
  1361. FailsWithClassNameDenyListAndClassNameAllowListAndError) {
  1362. - std::string model_path = JoinPath("./" /*test src dir*/,
  1363. - kTestDataDirectory, kMobileSsdWithMetadata);
  1364. + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1365. + kMobileSsdWithMetadata);
  1366. TfLiteObjectDetectorOptions options = TfLiteObjectDetectorOptionsCreate();
  1367. options.base_options.model_file.file_path = model_path.data();
  1368. @@ -183,7 +188,8 @@ TEST_F(ObjectDetectorFromOptionsTest,
  1369. TfLiteObjectDetectorFromOptions(&options, &error);
  1370. EXPECT_EQ(object_detector, nullptr);
  1371. - if (object_detector) TfLiteObjectDetectorDelete(object_detector);
  1372. + if (object_detector)
  1373. + TfLiteObjectDetectorDelete(object_detector);
  1374. ASSERT_NE(error, nullptr);
  1375. EXPECT_EQ(error->code, kInvalidArgumentError);
  1376. @@ -195,7 +201,8 @@ TEST_F(ObjectDetectorFromOptionsTest,
  1377. TEST(ObjectDetectorNullDetectorDetectTest,
  1378. FailsWithNullObjectDetectorAndError) {
  1379. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("cats_and_dogs.jpg"));
  1380. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1381. + LoadImage("cats_and_dogs.jpg"));
  1382. TfLiteSupportError* error = nullptr;
  1383. TfLiteDetectionResult* detection_result =
  1384. @@ -204,7 +211,8 @@ TEST(ObjectDetectorNullDetectorDetectTest,
  1385. ImageDataFree(&image_data);
  1386. EXPECT_EQ(detection_result, nullptr);
  1387. - if (detection_result) TfLiteDetectionResultDelete(detection_result);
  1388. + if (detection_result)
  1389. + TfLiteDetectionResultDelete(detection_result);
  1390. ASSERT_NE(error, nullptr);
  1391. EXPECT_EQ(error->code, kInvalidArgumentError);
  1392. @@ -217,9 +225,8 @@ TEST(ObjectDetectorNullDetectorDetectTest,
  1393. class ObjectDetectorDetectTest : public tflite_shims::testing::Test {
  1394. protected:
  1395. void SetUp() override {
  1396. - std::string model_path =
  1397. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1398. - kMobileSsdWithMetadata);
  1399. + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1400. + kMobileSsdWithMetadata);
  1401. TfLiteObjectDetectorOptions options = TfLiteObjectDetectorOptionsCreate();
  1402. options.base_options.model_file.file_path = model_path.data();
  1403. @@ -232,7 +239,8 @@ class ObjectDetectorDetectTest : public tflite_shims::testing::Test {
  1404. };
  1405. TEST_F(ObjectDetectorDetectTest, SucceedsWithImageData) {
  1406. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("cats_and_dogs.jpg"));
  1407. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1408. + LoadImage("cats_and_dogs.jpg"));
  1409. TfLiteFrameBuffer frame_buffer = {
  1410. .format = kRGB,
  1411. @@ -251,7 +259,8 @@ TEST_F(ObjectDetectorDetectTest, SucceedsWithImageData) {
  1412. }
  1413. TEST_F(ObjectDetectorDetectTest, FailsWithNullFrameBufferAndError) {
  1414. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("cats_and_dogs.jpg"));
  1415. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1416. + LoadImage("cats_and_dogs.jpg"));
  1417. TfLiteSupportError* error = nullptr;
  1418. TfLiteDetectionResult* detection_result =
  1419. @@ -260,7 +269,8 @@ TEST_F(ObjectDetectorDetectTest, FailsWithNullFrameBufferAndError) {
  1420. ImageDataFree(&image_data);
  1421. EXPECT_EQ(detection_result, nullptr);
  1422. - if (detection_result) TfLiteDetectionResultDelete(detection_result);
  1423. + if (detection_result)
  1424. + TfLiteDetectionResultDelete(detection_result);
  1425. ASSERT_NE(error, nullptr);
  1426. EXPECT_EQ(error->code, kInvalidArgumentError);
  1427. @@ -271,7 +281,8 @@ TEST_F(ObjectDetectorDetectTest, FailsWithNullFrameBufferAndError) {
  1428. }
  1429. TEST_F(ObjectDetectorDetectTest, FailsWithNullImageDataAndError) {
  1430. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("cats_and_dogs.jpg"));
  1431. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1432. + LoadImage("cats_and_dogs.jpg"));
  1433. TfLiteSupportError* error = nullptr;
  1434. TfLiteDetectionResult* detection_result =
  1435. TfLiteObjectDetectorDetect(object_detector, nullptr, &error);
  1436. @@ -279,7 +290,8 @@ TEST_F(ObjectDetectorDetectTest, FailsWithNullImageDataAndError) {
  1437. ImageDataFree(&image_data);
  1438. EXPECT_EQ(detection_result, nullptr);
  1439. - if (detection_result) TfLiteDetectionResultDelete(detection_result);
  1440. + if (detection_result)
  1441. + TfLiteDetectionResultDelete(detection_result);
  1442. ASSERT_NE(error, nullptr);
  1443. EXPECT_EQ(error->code, kInvalidArgumentError);
  1444. @@ -292,8 +304,8 @@ TEST_F(ObjectDetectorDetectTest, FailsWithNullImageDataAndError) {
  1445. TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
  1446. SucceedsWithClassNameDenyList) {
  1447. char* denylisted_label_name = (char*)"cat";
  1448. - std::string model_path = JoinPath("./" /*test src dir*/,
  1449. - kTestDataDirectory, kMobileSsdWithMetadata);
  1450. + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1451. + kMobileSsdWithMetadata);
  1452. TfLiteObjectDetectorOptions options = TfLiteObjectDetectorOptionsCreate();
  1453. options.base_options.model_file.file_path = model_path.data();
  1454. @@ -306,7 +318,8 @@ TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
  1455. TfLiteObjectDetectorFromOptions(&options, nullptr);
  1456. ASSERT_NE(object_detector, nullptr);
  1457. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("cats_and_dogs.jpg"));
  1458. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1459. + LoadImage("cats_and_dogs.jpg"));
  1460. TfLiteFrameBuffer frame_buffer = {
  1461. .format = kRGB,
  1462. @@ -318,7 +331,8 @@ TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
  1463. TfLiteObjectDetectorDetect(object_detector, &frame_buffer, nullptr);
  1464. ImageDataFree(&image_data);
  1465. - if (object_detector) TfLiteObjectDetectorDelete(object_detector);
  1466. + if (object_detector)
  1467. + TfLiteObjectDetectorDelete(object_detector);
  1468. ASSERT_NE(detection_result, nullptr);
  1469. EXPECT_GE(detection_result->size, 1);
  1470. @@ -334,8 +348,8 @@ TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
  1471. TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
  1472. SucceedsWithClassNameAllowList) {
  1473. char* allowlisted_label_name = (char*)"cat";
  1474. - std::string model_path = JoinPath("./" /*test src dir*/,
  1475. - kTestDataDirectory, kMobileSsdWithMetadata)
  1476. + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1477. + kMobileSsdWithMetadata)
  1478. .data();
  1479. TfLiteObjectDetectorOptions options = TfLiteObjectDetectorOptionsCreate();
  1480. @@ -349,7 +363,8 @@ TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
  1481. TfLiteObjectDetectorFromOptions(&options, nullptr);
  1482. ASSERT_NE(object_detector, nullptr);
  1483. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("cats_and_dogs.jpg"));
  1484. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1485. + LoadImage("cats_and_dogs.jpg"));
  1486. TfLiteFrameBuffer frame_buffer = {
  1487. .format = kRGB,
  1488. @@ -361,7 +376,8 @@ TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
  1489. TfLiteObjectDetectorDetect(object_detector, &frame_buffer, nullptr);
  1490. ImageDataFree(&image_data);
  1491. - if (object_detector) TfLiteObjectDetectorDelete(object_detector);
  1492. + if (object_detector)
  1493. + TfLiteObjectDetectorDelete(object_detector);
  1494. ASSERT_NE(detection_result, nullptr);
  1495. EXPECT_GE(detection_result->size, 1);
  1496. @@ -376,8 +392,8 @@ TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
  1497. TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
  1498. SucceedsWithScoreThreshold) {
  1499. - std::string model_path = JoinPath("./" /*test src dir*/,
  1500. - kTestDataDirectory, kMobileSsdWithMetadata)
  1501. + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
  1502. + kMobileSsdWithMetadata)
  1503. .data();
  1504. TfLiteObjectDetectorOptions options = TfLiteObjectDetectorOptionsCreate();
  1505. @@ -389,7 +405,8 @@ TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
  1506. TfLiteObjectDetectorFromOptions(&options, nullptr);
  1507. ASSERT_NE(object_detector, nullptr);
  1508. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("cats_and_dogs.jpg"));
  1509. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
  1510. + LoadImage("cats_and_dogs.jpg"));
  1511. TfLiteFrameBuffer frame_buffer = {
  1512. .format = kRGB,
  1513. @@ -401,7 +418,8 @@ TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
  1514. TfLiteObjectDetectorDetect(object_detector, &frame_buffer, nullptr);
  1515. ImageDataFree(&image_data);
  1516. - if (object_detector) TfLiteObjectDetectorDelete(object_detector);
  1517. + if (object_detector)
  1518. + TfLiteObjectDetectorDelete(object_detector);
  1519. ASSERT_NE(detection_result, nullptr);
  1520. EXPECT_EQ(detection_result->size, 1);
  1521. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc
  1522. index abfef722d6659..09e9a83e07bef 100644
  1523. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc
  1524. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc
  1525. @@ -15,7 +15,7 @@ limitations under the License.
  1526. #include "tensorflow_lite_support/cc/common.h"
  1527. -#include "absl/strings/cord.h" // from @com_google_absl
  1528. +#include "absl/strings/cord.h" // from @com_google_absl
  1529. #include "absl/strings/str_cat.h" // from @com_google_absl
  1530. namespace tflite {
  1531. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h
  1532. index b06e9f58459af..71dd920b86bed 100644
  1533. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h
  1534. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h
  1535. @@ -16,7 +16,7 @@ limitations under the License.
  1536. #ifndef TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_
  1537. #define TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_
  1538. -#include "absl/status/status.h" // from @com_google_absl
  1539. +#include "absl/status/status.h" // from @com_google_absl
  1540. #include "absl/strings/string_view.h" // from @com_google_absl
  1541. namespace tflite {
  1542. @@ -164,7 +164,8 @@ enum class TfLiteSupportStatus {
  1543. // more than returning an object identical to an OK status. See `absl::Status`
  1544. // for more details.
  1545. absl::Status CreateStatusWithPayload(
  1546. - absl::StatusCode canonical_code, absl::string_view message,
  1547. + absl::StatusCode canonical_code,
  1548. + absl::string_view message,
  1549. tflite::support::TfLiteSupportStatus tfls_code =
  1550. tflite::support::TfLiteSupportStatus::kError);
  1551. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h
  1552. index 14999ca37b7ac..cb145dbd232c8 100644
  1553. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h
  1554. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h
  1555. @@ -18,7 +18,7 @@ limitations under the License.
  1556. #define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_
  1557. #include "absl/base/optimization.h" // from @com_google_absl
  1558. -#include "absl/status/status.h" // from @com_google_absl
  1559. +#include "absl/status/status.h" // from @com_google_absl
  1560. // Evaluates an expression that produces a `absl::Status`. If the status is not
  1561. // ok, returns it from the current function.
  1562. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h
  1563. index dc04c293c6ffd..81ec3c1ab5f86 100644
  1564. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h
  1565. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h
  1566. @@ -21,8 +21,8 @@ limitations under the License.
  1567. #include <utility>
  1568. #include "absl/meta/type_traits.h" // from @com_google_absl
  1569. -#include "absl/status/status.h" // from @com_google_absl
  1570. -#include "absl/utility/utility.h" // from @com_google_absl
  1571. +#include "absl/status/status.h" // from @com_google_absl
  1572. +#include "absl/utility/utility.h" // from @com_google_absl
  1573. namespace tflite {
  1574. namespace support {
  1575. @@ -63,7 +63,8 @@ struct IsDirectInitializationAmbiguous
  1576. U>::value,
  1577. std::false_type,
  1578. IsDirectInitializationAmbiguous<
  1579. - T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
  1580. + T,
  1581. + absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
  1582. template <typename T, typename V>
  1583. struct IsDirectInitializationAmbiguous<T, tflite::support::StatusOr<V>>
  1584. @@ -101,7 +102,8 @@ struct IsForwardingAssignmentAmbiguous
  1585. U>::value,
  1586. std::false_type,
  1587. IsForwardingAssignmentAmbiguous<
  1588. - T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
  1589. + T,
  1590. + absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
  1591. template <typename T, typename U>
  1592. struct IsForwardingAssignmentAmbiguous<T, tflite::support::StatusOr<U>>
  1593. @@ -136,7 +138,8 @@ template <typename T, typename... Args>
  1594. void PlacementNew(void* p, Args&&... args) {
  1595. #if defined(__GNUC__) && !defined(__clang__)
  1596. // Teach gcc that 'p' cannot be null, fixing code size issues.
  1597. - if (p == nullptr) __builtin_unreachable();
  1598. + if (p == nullptr)
  1599. + __builtin_unreachable();
  1600. #endif
  1601. new (p) T(std::forward<Args>(args)...);
  1602. }
  1603. @@ -207,7 +210,8 @@ class StatusOrData {
  1604. }
  1605. StatusOrData& operator=(const StatusOrData& other) {
  1606. - if (this == &other) return *this;
  1607. + if (this == &other)
  1608. + return *this;
  1609. if (other.ok())
  1610. Assign(other.data_);
  1611. else
  1612. @@ -216,7 +220,8 @@ class StatusOrData {
  1613. }
  1614. StatusOrData& operator=(StatusOrData&& other) {
  1615. - if (this == &other) return *this;
  1616. + if (this == &other)
  1617. + return *this;
  1618. if (other.ok())
  1619. Assign(std::move(other.data_));
  1620. else
  1621. @@ -295,15 +300,18 @@ class StatusOrData {
  1622. };
  1623. void Clear() {
  1624. - if (ok()) data_.~T();
  1625. + if (ok())
  1626. + data_.~T();
  1627. }
  1628. void EnsureOk() const {
  1629. - if (ABSL_PREDICT_FALSE(!ok())) Helper::Crash(status_);
  1630. + if (ABSL_PREDICT_FALSE(!ok()))
  1631. + Helper::Crash(status_);
  1632. }
  1633. void EnsureNotOk() {
  1634. - if (ABSL_PREDICT_FALSE(ok())) Helper::HandleInvalidStatusCtorArg(&status_);
  1635. + if (ABSL_PREDICT_FALSE(ok()))
  1636. + Helper::HandleInvalidStatusCtorArg(&status_);
  1637. }
  1638. // Construct the value (ie. data_) through placement new with the passed
  1639. @@ -362,8 +370,9 @@ struct MoveCtorBase<T, false> {
  1640. MoveCtorBase& operator=(MoveCtorBase&&) = default;
  1641. };
  1642. -template <typename T, bool = std::is_copy_constructible<T>::value&&
  1643. - std::is_copy_assignable<T>::value>
  1644. +template <typename T,
  1645. + bool = std::is_copy_constructible<T>::value&&
  1646. + std::is_copy_assignable<T>::value>
  1647. struct CopyAssignBase {
  1648. CopyAssignBase() = default;
  1649. CopyAssignBase(const CopyAssignBase&) = default;
  1650. @@ -381,8 +390,9 @@ struct CopyAssignBase<T, false> {
  1651. CopyAssignBase& operator=(CopyAssignBase&&) = default;
  1652. };
  1653. -template <typename T, bool = std::is_move_constructible<T>::value&&
  1654. - std::is_move_assignable<T>::value>
  1655. +template <typename T,
  1656. + bool = std::is_move_constructible<T>::value&&
  1657. + std::is_move_assignable<T>::value>
  1658. struct MoveAssignBase {
  1659. MoveAssignBase() = default;
  1660. MoveAssignBase(const MoveAssignBase&) = default;
  1661. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
  1662. index 11f9d584cfdd0..4d23efe43bc99 100644
  1663. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
  1664. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
  1665. @@ -15,7 +15,7 @@ limitations under the License.
  1666. #include "tensorflow_lite_support/cc/port/default/tflite_wrapper.h"
  1667. -#include "absl/status/status.h" // from @com_google_absl
  1668. +#include "absl/status/status.h" // from @com_google_absl
  1669. #include "absl/strings/str_format.h" // from @com_google_absl
  1670. #include "tensorflow/lite/c/common.h"
  1671. #include "tensorflow/lite/delegates/interpreter_utils.h"
  1672. @@ -310,7 +310,9 @@ absl::Status TfLiteInterpreterWrapper::InvokeWithoutFallback() {
  1673. return absl::OkStatus();
  1674. }
  1675. -void TfLiteInterpreterWrapper::Cancel() { cancel_flag_.Set(true); }
  1676. +void TfLiteInterpreterWrapper::Cancel() {
  1677. + cancel_flag_.Set(true);
  1678. +}
  1679. void TfLiteInterpreterWrapper::SetTfLiteCancellation() {
  1680. // Create a cancellation check function and set to the TFLite interpreter.
  1681. @@ -323,7 +325,8 @@ void TfLiteInterpreterWrapper::SetTfLiteCancellation() {
  1682. }
  1683. absl::Status TfLiteInterpreterWrapper::LoadDelegatePlugin(
  1684. - const std::string& name, const tflite::TFLiteSettings& tflite_settings) {
  1685. + const std::string& name,
  1686. + const tflite::TFLiteSettings& tflite_settings) {
  1687. delegate_plugin_ = DelegatePluginRegistry::CreateByName(
  1688. absl::StrFormat("%sPlugin", name), tflite_settings);
  1689. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h
  1690. index 9a6fdebd99903..a9deed9f93521 100644
  1691. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h
  1692. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h
  1693. @@ -19,7 +19,7 @@ limitations under the License.
  1694. #include <string>
  1695. #include <utility>
  1696. -#include "absl/status/status.h" // from @com_google_absl
  1697. +#include "absl/status/status.h" // from @com_google_absl
  1698. #include "flatbuffers/flatbuffers.h" // from @flatbuffers
  1699. #include "tensorflow/lite/c/common.h"
  1700. #include "tensorflow/lite/experimental/acceleration/configuration/configuration.pb.h"
  1701. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h
  1702. index 0d808ab24d6cc..dc6183bee693c 100644
  1703. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h
  1704. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h
  1705. @@ -37,7 +37,7 @@ typedef unsigned long uword_t;
  1706. #define GG_LL_FORMAT "ll" // As in "%lld". Note that "q" is poor form also.
  1707. #define GG_LL_FORMAT_W L"ll"
  1708. -const uint8 kuint8max{0xFF};
  1709. +const uint8 kuint8max{0xFF};
  1710. const uint16 kuint16max{0xFFFF};
  1711. const uint32 kuint32max{0xFFFFFFFF};
  1712. const uint64 kuint64max{GG_ULONGLONG(0xFFFFFFFFFFFFFFFF)};
  1713. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc
  1714. index 4b1439dcc0719..4be3e53c11972 100644
  1715. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc
  1716. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc
  1717. @@ -17,7 +17,7 @@ limitations under the License.
  1718. #include <initializer_list>
  1719. -#include "absl/status/status.h" // from @com_google_absl
  1720. +#include "absl/status/status.h" // from @com_google_absl
  1721. #include "absl/strings/str_format.h" // from @com_google_absl
  1722. #include "tensorflow/lite/c/c_api_types.h"
  1723. #include "tensorflow_lite_support/cc/common.h"
  1724. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.cc
  1725. index 56acada352121..a01effd031e29 100644
  1726. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.cc
  1727. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.cc
  1728. @@ -29,7 +29,8 @@ namespace audio {
  1729. /* static */
  1730. tflite::support::StatusOr<double> AudioEmbedder::CosineSimilarity(
  1731. - const processor::FeatureVector& u, const processor::FeatureVector& v) {
  1732. + const processor::FeatureVector& u,
  1733. + const processor::FeatureVector& v) {
  1734. return processor::EmbeddingPostprocessor::CosineSimilarity(u, v);
  1735. }
  1736. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h
  1737. index f6df6d4d58552..4a139ee8bf82d 100644
  1738. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h
  1739. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h
  1740. @@ -27,9 +27,9 @@ limitations under the License.
  1741. namespace tflite {
  1742. namespace task {
  1743. namespace audio {
  1744. -class AudioEmbedder
  1745. - : public tflite::task::core::BaseTaskApi<
  1746. - tflite::task::processor::EmbeddingResult, const AudioBuffer&> {
  1747. +class AudioEmbedder : public tflite::task::core::BaseTaskApi<
  1748. + tflite::task::processor::EmbeddingResult,
  1749. + const AudioBuffer&> {
  1750. public:
  1751. // Use base class constructor.
  1752. using BaseTaskApi::BaseTaskApi;
  1753. @@ -41,7 +41,8 @@ class AudioEmbedder
  1754. //
  1755. // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
  1756. static tflite::support::StatusOr<double> CosineSimilarity(
  1757. - const processor::FeatureVector& u, const processor::FeatureVector& v);
  1758. + const processor::FeatureVector& u,
  1759. + const processor::FeatureVector& v);
  1760. // Creates an AudioEmbedder from the provided options. A non-default
  1761. // OpResolver can be specified in order to support custom Ops or specify a
  1762. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h
  1763. index 39110ed8d0b15..d922e48af25bc 100644
  1764. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h
  1765. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h
  1766. @@ -17,8 +17,8 @@ limitations under the License.
  1767. #include <memory>
  1768. -#include "absl/memory/memory.h" // from @com_google_absl
  1769. -#include "absl/status/status.h" // from @com_google_absl
  1770. +#include "absl/memory/memory.h" // from @com_google_absl
  1771. +#include "absl/status/status.h" // from @com_google_absl
  1772. #include "absl/strings/str_format.h" // from @com_google_absl
  1773. #include "tensorflow_lite_support/cc/common.h"
  1774. #include "tensorflow_lite_support/cc/port/statusor.h"
  1775. @@ -41,7 +41,8 @@ class AudioBuffer {
  1776. // Factory method for creating an AudioBuffer object. The internal buffer does
  1777. // not take the ownership of the input backing buffer.
  1778. static tflite::support::StatusOr<std::unique_ptr<AudioBuffer>> Create(
  1779. - const float* audio_buffer, int buffer_size,
  1780. + const float* audio_buffer,
  1781. + int buffer_size,
  1782. const AudioFormat& audio_format) {
  1783. return absl::make_unique<AudioBuffer>(audio_buffer, buffer_size,
  1784. audio_format);
  1785. @@ -50,7 +51,8 @@ class AudioBuffer {
  1786. // AudioBuffer for internal use only. Uses the factory method to construct
  1787. // AudioBuffer instance. The internal buffer does not take the ownership of
  1788. // the input backing buffer.
  1789. - AudioBuffer(const float* audio_buffer, int buffer_size,
  1790. + AudioBuffer(const float* audio_buffer,
  1791. + int buffer_size,
  1792. const AudioFormat& audio_format)
  1793. : audio_buffer_(audio_buffer),
  1794. buffer_size_(buffer_size),
  1795. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/audio_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/audio_utils.cc
  1796. index 1a27c6b44c1bf..c013759b13ebb 100644
  1797. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/audio_utils.cc
  1798. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/audio_utils.cc
  1799. @@ -20,7 +20,8 @@ namespace task {
  1800. namespace audio {
  1801. tflite::support::StatusOr<AudioBuffer> LoadAudioBufferFromFile(
  1802. - const std::string& wav_file_path, int buffer_size,
  1803. + const std::string& wav_file_path,
  1804. + int buffer_size,
  1805. std::vector<float>* wav_data) {
  1806. std::string contents = ReadFile(wav_file_path);
  1807. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/audio_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/audio_utils.h
  1808. index 68880c0cb4072..123d5a1f6fbf7 100644
  1809. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/audio_utils.h
  1810. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/audio_utils.h
  1811. @@ -34,7 +34,8 @@ namespace audio {
  1812. // object, the user of this function has to make sure that wav_data outlives the
  1813. // returned AudioBuffer object.
  1814. tflite::support::StatusOr<AudioBuffer> LoadAudioBufferFromFile(
  1815. - const std::string& wav_file_path, int buffer_size,
  1816. + const std::string& wav_file_path,
  1817. + int buffer_size,
  1818. std::vector<float>* wav_data);
  1819. } // namespace audio
  1820. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc
  1821. index 3c0ad996a9919..9ae3fbec70543 100644
  1822. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc
  1823. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc
  1824. @@ -27,9 +27,9 @@ limitations under the License.
  1825. #include <fstream>
  1826. #include <limits>
  1827. -#include "absl/base/casts.h" // from @com_google_absl
  1828. -#include "absl/status/status.h" // from @com_google_absl
  1829. -#include "absl/strings/str_cat.h" // from @com_google_absl
  1830. +#include "absl/base/casts.h" // from @com_google_absl
  1831. +#include "absl/status/status.h" // from @com_google_absl
  1832. +#include "absl/strings/str_cat.h" // from @com_google_absl
  1833. #include "absl/strings/str_format.h" // from @com_google_absl
  1834. #include "tensorflow_lite_support/cc/port/status_macros.h"
  1835. @@ -62,7 +62,9 @@ std::string ReadFile(const std::string filepath) {
  1836. // Handles moving the data index forward, validating the arguments, and avoiding
  1837. // overflow or underflow.
  1838. -absl::Status IncrementOffset(int old_offset, size_t increment, size_t max_size,
  1839. +absl::Status IncrementOffset(int old_offset,
  1840. + size_t increment,
  1841. + size_t max_size,
  1842. int* new_offset) {
  1843. if (old_offset < 0) {
  1844. return absl::InvalidArgumentError(
  1845. @@ -87,7 +89,8 @@ absl::Status IncrementOffset(int old_offset, size_t increment, size_t max_size,
  1846. }
  1847. absl::Status ExpectText(const std::string& data,
  1848. - const std::string& expected_text, int* offset) {
  1849. + const std::string& expected_text,
  1850. + int* offset) {
  1851. int new_offset;
  1852. RETURN_IF_ERROR(
  1853. IncrementOffset(*offset, expected_text.size(), data.size(), &new_offset));
  1854. @@ -101,8 +104,10 @@ absl::Status ExpectText(const std::string& data,
  1855. return absl::OkStatus();
  1856. }
  1857. -absl::Status ReadString(const std::string& data, int expected_length,
  1858. - std::string* value, int* offset) {
  1859. +absl::Status ReadString(const std::string& data,
  1860. + int expected_length,
  1861. + std::string* value,
  1862. + int* offset) {
  1863. int new_offset;
  1864. RETURN_IF_ERROR(
  1865. IncrementOffset(*offset, expected_length, data.size(), &new_offset));
  1866. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h
  1867. index 51271fc065c83..9aca5d06f7985 100644
  1868. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h
  1869. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h
  1870. @@ -20,9 +20,9 @@ limitations under the License.
  1871. #define TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_UTILS_WAV_IO_H_
  1872. +#include <cstdint>
  1873. #include <string>
  1874. #include <vector>
  1875. -#include <cstdint>
  1876. #include "absl/status/status.h" // from @com_google_absl
  1877. #include "tensorflow_lite_support/cc/port/status_macros.h"
  1878. @@ -64,7 +64,9 @@ absl::Status DecodeLin16WaveAsFloatVector(const std::string& wav_string,
  1879. // Handles moving the data index forward, validating the arguments, and avoiding
  1880. // overflow or underflow.
  1881. -absl::Status IncrementOffset(int old_offset, size_t increment, size_t max_size,
  1882. +absl::Status IncrementOffset(int old_offset,
  1883. + size_t increment,
  1884. + size_t max_size,
  1885. int* new_offset);
  1886. // This function is only exposed in the header for testing purposes, as a
  1887. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h
  1888. index d743383734b42..effd42f0f0336 100644
  1889. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h
  1890. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h
  1891. @@ -18,7 +18,7 @@ limitations under the License.
  1892. #include <utility>
  1893. -#include "absl/status/status.h" // from @com_google_absl
  1894. +#include "absl/status/status.h" // from @com_google_absl
  1895. #include "absl/strings/string_view.h" // from @com_google_absl
  1896. #include "tensorflow/lite/c/common.h"
  1897. #include "tensorflow_lite_support/cc/common.h"
  1898. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h
  1899. index c868060f9894a..c91552f7ec82e 100644
  1900. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h
  1901. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h
  1902. @@ -18,7 +18,7 @@ limitations under the License.
  1903. #include <string>
  1904. #include <vector>
  1905. -#include "absl/memory/memory.h" // from @com_google_absl
  1906. +#include "absl/memory/memory.h" // from @com_google_absl
  1907. #include "absl/strings/string_view.h" // from @com_google_absl
  1908. #include "tensorflow_lite_support/cc/port/statusor.h"
  1909. #include "tensorflow_lite_support/cc/task/core/label_map_item.h"
  1910. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc
  1911. index 80dea95cce24b..a626ce6030b96 100644
  1912. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc
  1913. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc
  1914. @@ -35,9 +35,13 @@ int ErrorReporter::Report(const char* format, va_list args) {
  1915. return num_characters;
  1916. }
  1917. -std::string ErrorReporter::message() { return last_message_; }
  1918. +std::string ErrorReporter::message() {
  1919. + return last_message_;
  1920. +}
  1921. -std::string ErrorReporter::previous_message() { return second_last_message_; }
  1922. +std::string ErrorReporter::previous_message() {
  1923. + return second_last_message_;
  1924. +}
  1925. } // namespace core
  1926. } // namespace task
  1927. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc
  1928. index 9c4cc2009baea..e15830d5ab061 100644
  1929. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc
  1930. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc
  1931. @@ -18,11 +18,11 @@ limitations under the License.
  1932. #include <memory>
  1933. #include <string>
  1934. -#include "absl/memory/memory.h" // from @com_google_absl
  1935. +#include "absl/memory/memory.h" // from @com_google_absl
  1936. #include "absl/strings/str_format.h" // from @com_google_absl
  1937. #include "tensorflow_lite_support/cc/common.h"
  1938. -#include "tensorflow_lite_support/cc/port/statusor.h"
  1939. #include "tensorflow_lite_support/cc/port/status_macros.h"
  1940. +#include "tensorflow_lite_support/cc/port/statusor.h"
  1941. namespace tflite {
  1942. namespace task {
  1943. @@ -57,11 +57,10 @@ absl::Status ExternalFileHandler::MapExternalFile() {
  1944. StatusCode::kInvalidArgument,
  1945. "ExternalFile must specify 'file_content' in Chromium.",
  1946. TfLiteSupportStatus::kInvalidArgumentError);
  1947. -
  1948. }
  1949. absl::string_view ExternalFileHandler::GetFileContent() {
  1950. - return external_file_.file_content();
  1951. + return external_file_.file_content();
  1952. }
  1953. ExternalFileHandler::~ExternalFileHandler() = default;
  1954. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h
  1955. index a7daa175f77f5..9f35fdd6d09ce 100644
  1956. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h
  1957. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h
  1958. @@ -18,7 +18,7 @@ limitations under the License.
  1959. #include <memory>
  1960. -#include "absl/status/status.h" // from @com_google_absl
  1961. +#include "absl/status/status.h" // from @com_google_absl
  1962. #include "absl/strings/string_view.h" // from @com_google_absl
  1963. #include "tensorflow_lite_support/cc/port/integral_types.h"
  1964. #include "tensorflow_lite_support/cc/port/statusor.h"
  1965. @@ -64,7 +64,6 @@ class ExternalFileHandler {
  1966. // Reference to the input ExternalFile.
  1967. const ExternalFile& external_file_;
  1968. -
  1969. };
  1970. } // namespace core
  1971. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc
  1972. index 694c55ab34e78..72e4b670cb172 100644
  1973. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc
  1974. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc
  1975. @@ -15,7 +15,7 @@ limitations under the License.
  1976. #include "tensorflow_lite_support/cc/task/core/label_map_item.h"
  1977. #include "absl/strings/str_format.h" // from @com_google_absl
  1978. -#include "absl/strings/str_split.h" // from @com_google_absl
  1979. +#include "absl/strings/str_split.h" // from @com_google_absl
  1980. #include "tensorflow_lite_support/cc/common.h"
  1981. namespace tflite {
  1982. @@ -28,7 +28,8 @@ using ::tflite::support::StatusOr;
  1983. using ::tflite::support::TfLiteSupportStatus;
  1984. StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
  1985. - absl::string_view labels_file, absl::string_view display_names_file) {
  1986. + absl::string_view labels_file,
  1987. + absl::string_view display_names_file) {
  1988. if (labels_file.empty()) {
  1989. return CreateStatusWithPayload(StatusCode::kInvalidArgument,
  1990. "Expected non-empty labels file.",
  1991. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h
  1992. index 4d8422a2a572d..d8e1f70d8fab1 100644
  1993. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h
  1994. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h
  1995. @@ -20,8 +20,8 @@ limitations under the License.
  1996. #include "absl/container/flat_hash_map.h" // from @com_google_absl
  1997. #include "absl/container/flat_hash_set.h" // from @com_google_absl
  1998. -#include "absl/status/status.h" // from @com_google_absl
  1999. -#include "absl/strings/string_view.h" // from @com_google_absl
  2000. +#include "absl/status/status.h" // from @com_google_absl
  2001. +#include "absl/strings/string_view.h" // from @com_google_absl
  2002. #include "tensorflow_lite_support/cc/port/statusor.h"
  2003. namespace tflite {
  2004. @@ -49,7 +49,8 @@ struct LabelMapItem {
  2005. // Returns an error e.g. if there's a mismatch between the number of labels and
  2006. // display names.
  2007. tflite::support::StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
  2008. - absl::string_view labels_file, absl::string_view display_names_file);
  2009. + absl::string_view labels_file,
  2010. + absl::string_view display_names_file);
  2011. // A class that represents a hierarchy of labels as specified in a label map.
  2012. //
  2013. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc
  2014. index 818839a77e43d..e7faebad487b9 100644
  2015. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc
  2016. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc
  2017. @@ -19,11 +19,11 @@ limitations under the License.
  2018. #include <utility>
  2019. #include <vector>
  2020. -#include "absl/status/status.h" // from @com_google_absl
  2021. -#include "absl/strings/str_format.h" // from @com_google_absl
  2022. -#include "absl/strings/str_split.h" // from @com_google_absl
  2023. +#include "absl/status/status.h" // from @com_google_absl
  2024. +#include "absl/strings/str_format.h" // from @com_google_absl
  2025. +#include "absl/strings/str_split.h" // from @com_google_absl
  2026. #include "absl/strings/string_view.h" // from @com_google_absl
  2027. -#include "absl/types/optional.h" // from @com_google_absl
  2028. +#include "absl/types/optional.h" // from @com_google_absl
  2029. #include "tensorflow_lite_support/cc/common.h"
  2030. #include "tensorflow_lite_support/cc/port/status_macros.h"
  2031. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h
  2032. index c1b945f76ab48..6e2b308bef101 100644
  2033. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h
  2034. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h
  2035. @@ -23,9 +23,9 @@ limitations under the License.
  2036. #include <vector>
  2037. #include "absl/container/flat_hash_map.h" // from @com_google_absl
  2038. -#include "absl/status/status.h" // from @com_google_absl
  2039. -#include "absl/strings/string_view.h" // from @com_google_absl
  2040. -#include "absl/types/optional.h" // from @com_google_absl
  2041. +#include "absl/status/status.h" // from @com_google_absl
  2042. +#include "absl/strings/string_view.h" // from @com_google_absl
  2043. +#include "absl/types/optional.h" // from @com_google_absl
  2044. #include "tensorflow_lite_support/cc/port/statusor.h"
  2045. #include "tensorflow_lite_support/cc/task/core/label_map_item.h"
  2046. #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
  2047. @@ -37,7 +37,10 @@ namespace core {
  2048. // Sigmoid structure.
  2049. struct Sigmoid {
  2050. Sigmoid() : scale(1.0) {}
  2051. - Sigmoid(std::string label, float slope, float offset, float scale = 1.0,
  2052. + Sigmoid(std::string label,
  2053. + float slope,
  2054. + float offset,
  2055. + float scale = 1.0,
  2056. absl::optional<float> min_uncalibrated_score = absl::nullopt)
  2057. : label(label),
  2058. slope(slope),
  2059. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h
  2060. index 3d3bc801a6e5d..bbe549a802b39 100644
  2061. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h
  2062. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h
  2063. @@ -18,7 +18,7 @@ limitations under the License.
  2064. #include <memory>
  2065. -#include "absl/base/macros.h" // from @com_google_absl
  2066. +#include "absl/base/macros.h" // from @com_google_absl
  2067. #include "absl/status/status.h" // from @com_google_absl
  2068. #include "tensorflow/lite/core/api/op_resolver.h"
  2069. #include "tensorflow/lite/kernels/op_macros.h"
  2070. @@ -48,7 +48,8 @@ class TaskAPIFactory {
  2071. "Use CreateFromBaseOptions and configure model input from "
  2072. "tensorflow_lite_support/cc/task/core/proto/base_options.proto")
  2073. static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromBuffer(
  2074. - const char* buffer_data, size_t buffer_size,
  2075. + const char* buffer_data,
  2076. + size_t buffer_size,
  2077. std::unique_ptr<tflite::OpResolver> resolver =
  2078. absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>(),
  2079. int num_threads = 1,
  2080. @@ -156,7 +157,8 @@ class TaskAPIFactory {
  2081. private:
  2082. template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr>
  2083. static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromTfLiteEngine(
  2084. - std::unique_ptr<TfLiteEngine> engine, int num_threads,
  2085. + std::unique_ptr<TfLiteEngine> engine,
  2086. + int num_threads,
  2087. const tflite::proto::ComputeSettings& compute_settings =
  2088. tflite::proto::ComputeSettings()) {
  2089. tflite::proto::ComputeSettings settings_copy =
  2090. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h
  2091. index 9c26d154634e1..2c21a95a1b075 100644
  2092. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h
  2093. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h
  2094. @@ -21,12 +21,12 @@ limitations under the License.
  2095. #include <numeric>
  2096. #include <vector>
  2097. -#include "absl/memory/memory.h" // from @com_google_absl
  2098. -#include "absl/status/status.h" // from @com_google_absl
  2099. -#include "absl/strings/str_cat.h" // from @com_google_absl
  2100. -#include "absl/strings/str_format.h" // from @com_google_absl
  2101. +#include "absl/memory/memory.h" // from @com_google_absl
  2102. +#include "absl/status/status.h" // from @com_google_absl
  2103. +#include "absl/strings/str_cat.h" // from @com_google_absl
  2104. +#include "absl/strings/str_format.h" // from @com_google_absl
  2105. #include "absl/strings/string_view.h" // from @com_google_absl
  2106. -#include "flatbuffers/flatbuffers.h" // from @flatbuffers
  2107. +#include "flatbuffers/flatbuffers.h" // from @flatbuffers
  2108. #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
  2109. #include "tensorflow/lite/kernels/op_macros.h"
  2110. #include "tensorflow/lite/string_util.h"
  2111. @@ -66,9 +66,11 @@ tflite::support::StatusOr<T*> AssertAndReturnTypedTensor(
  2112. // type or has not the same number of elements.
  2113. // Note: std::negation is not used because it is from C++17, where the code will
  2114. // be compiled using C++14 in OSS.
  2115. -template <typename T, typename = std::enable_if_t<
  2116. - std::is_same<T, std::string>::value == false>>
  2117. -inline absl::Status PopulateTensor(const T* data, int num_elements,
  2118. +template <
  2119. + typename T,
  2120. + typename = std::enable_if_t<std::is_same<T, std::string>::value == false>>
  2121. +inline absl::Status PopulateTensor(const T* data,
  2122. + int num_elements,
  2123. TfLiteTensor* tensor) {
  2124. T* v;
  2125. ASSIGN_OR_RETURN(v, AssertAndReturnTypedTensor<T>(tensor));
  2126. @@ -93,7 +95,8 @@ inline absl::Status PopulateTensor(const std::vector<T>& data,
  2127. template <>
  2128. inline absl::Status PopulateTensor<std::string>(
  2129. - const std::vector<std::string>& data, TfLiteTensor* tensor) {
  2130. + const std::vector<std::string>& data,
  2131. + TfLiteTensor* tensor) {
  2132. if (tensor->type != kTfLiteString) {
  2133. return tflite::support::CreateStatusWithPayload(
  2134. absl::StatusCode::kInternal,
  2135. @@ -144,7 +147,8 @@ inline absl::Status PopulateVector(const TfLiteTensor* tensor,
  2136. template <>
  2137. inline absl::Status PopulateVector<std::string>(
  2138. - const TfLiteTensor* tensor, std::vector<std::string>* data) {
  2139. + const TfLiteTensor* tensor,
  2140. + std::vector<std::string>* data) {
  2141. std::string* v __attribute__((unused));
  2142. ASSIGN_OR_RETURN(v, AssertAndReturnTypedTensor<std::string>(tensor));
  2143. int num = GetStringCount(tensor);
  2144. @@ -160,7 +164,8 @@ inline absl::Status PopulateVector<std::string>(
  2145. // Note: std::negation is not used because it is from C++17, where the code will
  2146. // be compiled using C++14 in OSS.
  2147. template <
  2148. - class TRepeatedField, class T = float,
  2149. + class TRepeatedField,
  2150. + class T = float,
  2151. typename = std::enable_if_t<std::is_same<T, std::string>::value == false>>
  2152. inline absl::Status PopulateVectorToRepeated(const TfLiteTensor* tensor,
  2153. TRepeatedField* data) {
  2154. @@ -236,7 +241,8 @@ int FindTensorIndexByName(
  2155. if (tensor_metadata != nullptr && tensor_metadata->size() == tensors.size()) {
  2156. int index =
  2157. FindTensorIndexByMetadataName(tensor_metadata, metadata_tensor_name);
  2158. - if (index > -1) return index;
  2159. + if (index > -1)
  2160. + return index;
  2161. }
  2162. return FindTensorIndexByModelName(tensors, model_tensor_name);
  2163. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc
  2164. index 5999090cab973..41e06389af80b 100644
  2165. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc
  2166. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc
  2167. @@ -17,7 +17,7 @@ limitations under the License.
  2168. #include <memory>
  2169. -#include "absl/strings/match.h" // from @com_google_absl
  2170. +#include "absl/strings/match.h" // from @com_google_absl
  2171. #include "absl/strings/str_cat.h" // from @com_google_absl
  2172. #include "tensorflow/lite/builtin_ops.h"
  2173. #include "tensorflow/lite/core/shims/cc/kernels/register.h"
  2174. @@ -38,7 +38,8 @@ using ::tflite::support::CreateStatusWithPayload;
  2175. using ::tflite::support::InterpreterCreationResources;
  2176. using ::tflite::support::TfLiteSupportStatus;
  2177. -bool TfLiteEngine::Verifier::Verify(const char* data, int length,
  2178. +bool TfLiteEngine::Verifier::Verify(const char* data,
  2179. + int length,
  2180. tflite::ErrorReporter* reporter) {
  2181. return tflite_shims::Verify(data, length, reporter);
  2182. }
  2183. @@ -69,7 +70,8 @@ std::vector<const TfLiteTensor*> TfLiteEngine::GetOutputs() {
  2184. }
  2185. void TfLiteEngine::VerifyAndBuildModelFromBuffer(
  2186. - const char* buffer_data, size_t buffer_size,
  2187. + const char* buffer_data,
  2188. + size_t buffer_size,
  2189. TfLiteVerifier* extra_verifier) {
  2190. model_ = tflite_shims::FlatBufferModel::VerifyAndBuildFromBuffer(
  2191. buffer_data, buffer_size, extra_verifier, &error_reporter_);
  2192. @@ -116,7 +118,8 @@ absl::Status TfLiteEngine::InitializeFromModelFileHandler(
  2193. }
  2194. absl::Status TfLiteEngine::BuildModelFromFlatBuffer(
  2195. - const char* buffer_data, size_t buffer_size,
  2196. + const char* buffer_data,
  2197. + size_t buffer_size,
  2198. const tflite::proto::ComputeSettings& compute_settings) {
  2199. if (model_) {
  2200. return CreateStatusWithPayload(StatusCode::kInternal,
  2201. @@ -205,7 +208,8 @@ absl::Status TfLiteEngine::InitInterpreter(int num_threads) {
  2202. // absl::Status TfLiteEngine::InitInterpreter(
  2203. // const tflite::proto::ComputeSettings& compute_settings)
  2204. absl::Status TfLiteEngine::InitInterpreter(
  2205. - const tflite::proto::ComputeSettings& compute_settings, int num_threads) {
  2206. + const tflite::proto::ComputeSettings& compute_settings,
  2207. + int num_threads) {
  2208. ComputeSettings settings_copy = ComputeSettings(compute_settings);
  2209. settings_copy.mutable_tflite_settings()
  2210. ->mutable_cpu_settings()
  2211. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h
  2212. index 53dabdc4841d7..0cbaa738e6db6 100644
  2213. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h
  2214. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h
  2215. @@ -18,8 +18,8 @@ limitations under the License.
  2216. #include <memory>
  2217. -#include "absl/memory/memory.h" // from @com_google_absl
  2218. -#include "absl/status/status.h" // from @com_google_absl
  2219. +#include "absl/memory/memory.h" // from @com_google_absl
  2220. +#include "absl/status/status.h" // from @com_google_absl
  2221. #include "absl/strings/string_view.h" // from @com_google_absl
  2222. #include "tensorflow/lite/core/api/op_resolver.h"
  2223. #include "tensorflow/lite/core/shims/c/common.h"
  2224. @@ -96,7 +96,8 @@ class TfLiteEngine {
  2225. // object. This performs extra verification on the input data using
  2226. // tflite::Verify.
  2227. absl::Status BuildModelFromFlatBuffer(
  2228. - const char* buffer_data, size_t buffer_size,
  2229. + const char* buffer_data,
  2230. + size_t buffer_size,
  2231. const tflite::proto::ComputeSettings& compute_settings =
  2232. tflite::proto::ComputeSettings());
  2233. @@ -138,7 +139,8 @@ class TfLiteEngine {
  2234. // absl::Status TfLiteEngine::InitInterpreter(
  2235. // const tflite::proto::ComputeSettings& compute_settings)
  2236. absl::Status InitInterpreter(
  2237. - const tflite::proto::ComputeSettings& compute_settings, int num_threads);
  2238. + const tflite::proto::ComputeSettings& compute_settings,
  2239. + int num_threads);
  2240. // Cancels the on-going `Invoke()` call if any and if possible. This method
  2241. // can be called from a different thread than the one where `Invoke()` is
  2242. @@ -155,7 +157,8 @@ class TfLiteEngine {
  2243. // the FlatBuffer data provided as input.
  2244. class Verifier : public tflite::TfLiteVerifier {
  2245. public:
  2246. - bool Verify(const char* data, int length,
  2247. + bool Verify(const char* data,
  2248. + int length,
  2249. tflite::ErrorReporter* reporter) override;
  2250. };
  2251. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc
  2252. index e3ea2b134e3f4..254d0689e5ecc 100644
  2253. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc
  2254. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc
  2255. @@ -14,7 +14,7 @@ limitations under the License.
  2256. ==============================================================================*/
  2257. #include "tensorflow_lite_support/cc/task/processor/audio_preprocessor.h"
  2258. -#include "absl/status/status.h" // from @com_google_absl
  2259. +#include "absl/status/status.h" // from @com_google_absl
  2260. #include "absl/strings/str_format.h" // from @com_google_absl
  2261. #include "tensorflow_lite_support/cc/common.h"
  2262. #include "tensorflow_lite_support/cc/port/statusor.h"
  2263. @@ -29,7 +29,8 @@ namespace {
  2264. // Looks up AudioProperty from metadata. If no error occurs, the returned value
  2265. // is guaranteed to be valid (not null).
  2266. tflite::support::StatusOr<const AudioProperties*> GetAudioPropertiesSafe(
  2267. - const TensorMetadata* tensor_metadata, int input_index) {
  2268. + const TensorMetadata* tensor_metadata,
  2269. + int input_index) {
  2270. if (tensor_metadata->content() == nullptr ||
  2271. tensor_metadata->content()->content_properties() == nullptr) {
  2272. return CreateStatusWithPayload(
  2273. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc
  2274. index 9c11083c4f839..63962003f5e77 100644
  2275. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc
  2276. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc
  2277. @@ -17,7 +17,7 @@ limitations under the License.
  2278. #include <memory>
  2279. -#include "absl/status/status.h" // from @com_google_absl
  2280. +#include "absl/status/status.h" // from @com_google_absl
  2281. #include "absl/strings/str_format.h" // from @com_google_absl
  2282. #include "tensorflow/lite/c/c_api_types.h"
  2283. #include "tensorflow_lite_support/cc/port/status_macros.h"
  2284. @@ -42,7 +42,8 @@ using ::tflite::task::core::ScoreCalibration;
  2285. /* static */
  2286. tflite::support::StatusOr<std::unique_ptr<ClassificationPostprocessor>>
  2287. ClassificationPostprocessor::Create(
  2288. - core::TfLiteEngine* engine, const std::initializer_list<int> output_indices,
  2289. + core::TfLiteEngine* engine,
  2290. + const std::initializer_list<int> output_indices,
  2291. std::unique_ptr<ClassificationOptions> options) {
  2292. ASSIGN_OR_RETURN(auto processor,
  2293. Processor::Create<ClassificationPostprocessor>(
  2294. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h
  2295. index 7863e3aa82fb7..f04048d84b4ce 100644
  2296. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h
  2297. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h
  2298. @@ -69,8 +69,8 @@ class EmbeddingPostprocessor : public Postprocessor {
  2299. // Performs actual cosine similarity computation.
  2300. template <typename T>
  2301. - static tflite::support::StatusOr<double> ComputeCosineSimilarity(
  2302. - const T* u, const T* v, int num_elements);
  2303. + static tflite::support::StatusOr<double>
  2304. + ComputeCosineSimilarity(const T* u, const T* v, int num_elements);
  2305. template <typename T>
  2306. void NormalizeFeatureVector(T* feature_vector) const;
  2307. @@ -146,7 +146,8 @@ void EmbeddingPostprocessor::QuantizeFeatureVector(T* feature_vector) const {
  2308. /* static */
  2309. template <typename T>
  2310. tflite::support::StatusOr<double>
  2311. -EmbeddingPostprocessor::ComputeCosineSimilarity(const T* u, const T* v,
  2312. +EmbeddingPostprocessor::ComputeCosineSimilarity(const T* u,
  2313. + const T* v,
  2314. int num_elements) {
  2315. if (num_elements <= 0) {
  2316. return CreateStatusWithPayload(
  2317. @@ -174,7 +175,8 @@ EmbeddingPostprocessor::ComputeCosineSimilarity(const T* u, const T* v,
  2318. /* static */
  2319. template <typename T>
  2320. tflite::support::StatusOr<double> EmbeddingPostprocessor::CosineSimilarity(
  2321. - const T& u, const T& v) {
  2322. + const T& u,
  2323. + const T& v) {
  2324. if (u.has_value_string() && v.has_value_string()) {
  2325. if (u.value_string().size() != v.value_string().size()) {
  2326. return CreateStatusWithPayload(
  2327. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc
  2328. index 7ad4ad4703789..310a1f5eba724 100644
  2329. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc
  2330. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc
  2331. @@ -36,7 +36,8 @@ using ::tflite::task::vision::FrameBuffer;
  2332. /* static */
  2333. tflite::support::StatusOr<std::unique_ptr<ImagePreprocessor>>
  2334. ImagePreprocessor::Create(
  2335. - core::TfLiteEngine* engine, const std::initializer_list<int> input_indices,
  2336. + core::TfLiteEngine* engine,
  2337. + const std::initializer_list<int> input_indices,
  2338. const vision::FrameBufferUtils::ProcessEngine& process_engine) {
  2339. ASSIGN_OR_RETURN(auto processor,
  2340. Processor::Create<ImagePreprocessor>(
  2341. @@ -49,7 +50,8 @@ ImagePreprocessor::Create(
  2342. // Returns false if image preprocessing could be skipped, true otherwise.
  2343. bool ImagePreprocessor::IsImagePreprocessingNeeded(
  2344. - const FrameBuffer& frame_buffer, const BoundingBox& roi) {
  2345. + const FrameBuffer& frame_buffer,
  2346. + const BoundingBox& roi) {
  2347. // Is crop required?
  2348. if (roi.origin_x() != 0 || roi.origin_y() != 0 ||
  2349. roi.width() != frame_buffer.dimension().width ||
  2350. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h
  2351. index 4aad40b2afd97..b3c43605ac82e 100644
  2352. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h
  2353. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h
  2354. @@ -18,7 +18,7 @@ limitations under the License.
  2355. #include <initializer_list>
  2356. #include <vector>
  2357. -#include "absl/status/status.h" // from @com_google_absl
  2358. +#include "absl/status/status.h" // from @com_google_absl
  2359. #include "absl/strings/str_format.h" // from @com_google_absl
  2360. #include "tensorflow/lite/core/shims/c/common.h"
  2361. #include "tensorflow_lite_support/cc/common.h"
  2362. @@ -52,7 +52,8 @@ class Processor {
  2363. // num_expected_tensors, engine, tensor_indices);
  2364. template <typename T, EnableIfProcessorSubclass<T> = nullptr>
  2365. static tflite::support::StatusOr<std::unique_ptr<T>> Create(
  2366. - int num_expected_tensors, tflite::task::core::TfLiteEngine* engine,
  2367. + int num_expected_tensors,
  2368. + tflite::task::core::TfLiteEngine* engine,
  2369. const std::initializer_list<int> tensor_indices,
  2370. bool requires_metadata = true) {
  2371. auto processor = absl::make_unique<T>(engine, tensor_indices);
  2372. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc
  2373. index af923b4d6f2c1..58b77b6952de1 100644
  2374. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc
  2375. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc
  2376. @@ -55,7 +55,8 @@ StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile(
  2377. /* static */
  2378. StatusOr<std::unique_ptr<RegexPreprocessor>> RegexPreprocessor::Create(
  2379. - tflite::task::core::TfLiteEngine* engine, int input_tensor_index) {
  2380. + tflite::task::core::TfLiteEngine* engine,
  2381. + int input_tensor_index) {
  2382. ASSIGN_OR_RETURN(auto processor, Processor::Create<RegexPreprocessor>(
  2383. /* num_expected_tensors = */ 1, engine,
  2384. {input_tensor_index},
  2385. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h
  2386. index 1f92bcc18e524..bdd4e5e207a12 100644
  2387. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h
  2388. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h
  2389. @@ -34,7 +34,8 @@ namespace processor {
  2390. class RegexPreprocessor : public TextPreprocessor {
  2391. public:
  2392. static tflite::support::StatusOr<std::unique_ptr<RegexPreprocessor>> Create(
  2393. - tflite::task::core::TfLiteEngine* engine, int input_tensor_index);
  2394. + tflite::task::core::TfLiteEngine* engine,
  2395. + int input_tensor_index);
  2396. absl::Status Preprocess(const std::string& text);
  2397. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc
  2398. index 730c9919cadee..a2fa1f8243199 100644
  2399. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc
  2400. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc
  2401. @@ -22,17 +22,12 @@ limitations under the License.
  2402. #include <memory>
  2403. #include <vector>
  2404. -#include "tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h"
  2405. -#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h"
  2406. -#include "tensorflow_lite_support/scann_ondevice/cc/core/searcher.h"
  2407. -#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
  2408. -#include "tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h"
  2409. -#include "absl/memory/memory.h" // from @com_google_absl
  2410. -#include "absl/status/status.h" // from @com_google_absl
  2411. -#include "absl/strings/str_format.h" // from @com_google_absl
  2412. +#include "Eigen/Core" // from @eigen
  2413. +#include "absl/memory/memory.h" // from @com_google_absl
  2414. +#include "absl/status/status.h" // from @com_google_absl
  2415. +#include "absl/strings/str_format.h" // from @com_google_absl
  2416. #include "absl/strings/string_view.h" // from @com_google_absl
  2417. -#include "absl/types/span.h" // from @com_google_absl
  2418. -#include "Eigen/Core" // from @eigen
  2419. +#include "absl/types/span.h" // from @com_google_absl
  2420. #include "tensorflow_lite_support/cc/common.h"
  2421. #include "tensorflow_lite_support/cc/port/status_macros.h"
  2422. #include "tensorflow_lite_support/cc/port/statusor.h"
  2423. @@ -45,6 +40,11 @@ limitations under the License.
  2424. #include "tensorflow_lite_support/cc/task/processor/proto/search_result.pb.h"
  2425. #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
  2426. #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
  2427. +#include "tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h"
  2428. +#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h"
  2429. +#include "tensorflow_lite_support/scann_ondevice/cc/core/searcher.h"
  2430. +#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
  2431. +#include "tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h"
  2432. #include "tensorflow_lite_support/scann_ondevice/cc/index.h"
  2433. #include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h"
  2434. @@ -56,16 +56,16 @@ namespace {
  2435. constexpr int kNoNeighborId = -1;
  2436. +using ::tflite::TensorMetadata;
  2437. +using ::tflite::metadata::ModelMetadataExtractor;
  2438. +using ::tflite::scann_ondevice::Index;
  2439. +using ::tflite::scann_ondevice::IndexConfig;
  2440. using ::tflite::scann_ondevice::core::AsymmetricHashFindNeighbors;
  2441. using ::tflite::scann_ondevice::core::DistanceMeasure;
  2442. using ::tflite::scann_ondevice::core::FloatFindNeighbors;
  2443. using ::tflite::scann_ondevice::core::QueryInfo;
  2444. using ::tflite::scann_ondevice::core::ScannOnDeviceConfig;
  2445. using ::tflite::scann_ondevice::core::TopN;
  2446. -using ::tflite::TensorMetadata;
  2447. -using ::tflite::metadata::ModelMetadataExtractor;
  2448. -using ::tflite::scann_ondevice::Index;
  2449. -using ::tflite::scann_ondevice::IndexConfig;
  2450. using ::tflite::support::CreateStatusWithPayload;
  2451. using ::tflite::support::StatusOr;
  2452. using ::tflite::support::TfLiteSupportStatus;
  2453. @@ -212,7 +212,8 @@ absl::Status ConvertEmbeddingToEigenMatrix(const Embedding& embedding,
  2454. /* static */
  2455. StatusOr<std::unique_ptr<SearchPostprocessor>> SearchPostprocessor::Create(
  2456. - TfLiteEngine* engine, int output_index,
  2457. + TfLiteEngine* engine,
  2458. + int output_index,
  2459. std::unique_ptr<SearchOptions> search_options,
  2460. std::unique_ptr<EmbeddingOptions> embedding_options) {
  2461. ASSIGN_OR_RETURN(auto embedding_postprocessor,
  2462. @@ -316,7 +317,8 @@ absl::Status SearchPostprocessor::Init(
  2463. index_config_.scann_config().partitioner().search_fraction())),
  2464. partitioner_->NumPartitions());
  2465. } else {
  2466. - partitioner_ = absl::make_unique<tflite::scann_ondevice::core::NoOpPartitioner>();
  2467. + partitioner_ =
  2468. + absl::make_unique<tflite::scann_ondevice::core::NoOpPartitioner>();
  2469. num_leaves_to_search_ = partitioner_->NumPartitions();
  2470. }
  2471. @@ -330,7 +332,8 @@ absl::Status SearchPostprocessor::Init(
  2472. }
  2473. absl::Status SearchPostprocessor::QuantizedSearch(
  2474. - Eigen::Ref<Eigen::MatrixXf> query, std::vector<int> leaves_to_search,
  2475. + Eigen::Ref<Eigen::MatrixXf> query,
  2476. + std::vector<int> leaves_to_search,
  2477. absl::Span<TopN> top_n) {
  2478. int dim = index_config_.embedding_dim();
  2479. // Prepare QueryInfo used for all leaves.
  2480. @@ -360,7 +363,8 @@ absl::Status SearchPostprocessor::QuantizedSearch(
  2481. }
  2482. absl::Status SearchPostprocessor::LinearSearch(
  2483. - Eigen::Ref<Eigen::MatrixXf> query, std::vector<int> leaves_to_search,
  2484. + Eigen::Ref<Eigen::MatrixXf> query,
  2485. + std::vector<int> leaves_to_search,
  2486. absl::Span<TopN> top_n) {
  2487. int dim = index_config_.embedding_dim();
  2488. for (int leaf_id : leaves_to_search) {
  2489. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.h
  2490. index 47c78b64ba2ca..d79bc853148a9 100644
  2491. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.h
  2492. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.h
  2493. @@ -21,14 +21,9 @@ limitations under the License.
  2494. #include <memory>
  2495. #include <vector>
  2496. -#include "tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h"
  2497. -#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h"
  2498. -#include "tensorflow_lite_support/scann_ondevice/cc/core/searcher.h"
  2499. -#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
  2500. -#include "tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h"
  2501. +#include "Eigen/Core" // from @eigen
  2502. #include "absl/strings/string_view.h" // from @com_google_absl
  2503. -#include "absl/types/span.h" // from @com_google_absl
  2504. -#include "Eigen/Core" // from @eigen
  2505. +#include "absl/types/span.h" // from @com_google_absl
  2506. #include "tensorflow_lite_support/cc/port/statusor.h"
  2507. #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
  2508. #include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
  2509. @@ -37,6 +32,11 @@ limitations under the License.
  2510. #include "tensorflow_lite_support/cc/task/processor/proto/embedding_options.pb.h"
  2511. #include "tensorflow_lite_support/cc/task/processor/proto/search_options.pb.h"
  2512. #include "tensorflow_lite_support/cc/task/processor/proto/search_result.pb.h"
  2513. +#include "tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h"
  2514. +#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h"
  2515. +#include "tensorflow_lite_support/scann_ondevice/cc/core/searcher.h"
  2516. +#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
  2517. +#include "tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h"
  2518. #include "tensorflow_lite_support/scann_ondevice/cc/index.h"
  2519. #include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h"
  2520. @@ -55,7 +55,8 @@ namespace processor {
  2521. class SearchPostprocessor : public Postprocessor {
  2522. public:
  2523. static tflite::support::StatusOr<std::unique_ptr<SearchPostprocessor>> Create(
  2524. - tflite::task::core::TfLiteEngine* engine, int output_index,
  2525. + tflite::task::core::TfLiteEngine* engine,
  2526. + int output_index,
  2527. std::unique_ptr<SearchOptions> search_options,
  2528. std::unique_ptr<EmbeddingOptions> embedding_options =
  2529. std::make_unique<EmbeddingOptions>());
  2530. @@ -76,12 +77,14 @@ class SearchPostprocessor : public Postprocessor {
  2531. std::unique_ptr<EmbeddingPostprocessor> embedding_postprocessor,
  2532. std::unique_ptr<SearchOptions> options);
  2533. - absl::Status QuantizedSearch(Eigen::Ref<Eigen::MatrixXf> query,
  2534. - std::vector<int> leaves_to_search,
  2535. - absl::Span<tflite::scann_ondevice::core::TopN> top_n);
  2536. - absl::Status LinearSearch(Eigen::Ref<Eigen::MatrixXf> query,
  2537. - std::vector<int> leaves_to_search,
  2538. - absl::Span<tflite::scann_ondevice::core::TopN> top_n);
  2539. + absl::Status QuantizedSearch(
  2540. + Eigen::Ref<Eigen::MatrixXf> query,
  2541. + std::vector<int> leaves_to_search,
  2542. + absl::Span<tflite::scann_ondevice::core::TopN> top_n);
  2543. + absl::Status LinearSearch(
  2544. + Eigen::Ref<Eigen::MatrixXf> query,
  2545. + std::vector<int> leaves_to_search,
  2546. + absl::Span<tflite::scann_ondevice::core::TopN> top_n);
  2547. std::unique_ptr<SearchOptions> options_;
  2548. @@ -96,8 +99,10 @@ class SearchPostprocessor : public Postprocessor {
  2549. // ScaNN management.
  2550. int num_leaves_to_search_;
  2551. tflite::scann_ondevice::core::DistanceMeasure distance_measure_;
  2552. - std::unique_ptr<tflite::scann_ondevice::core::PartitionerInterface> partitioner_;
  2553. - std::shared_ptr<tflite::scann_ondevice::core::AsymmetricHashQuerier> quantizer_;
  2554. + std::unique_ptr<tflite::scann_ondevice::core::PartitionerInterface>
  2555. + partitioner_;
  2556. + std::shared_ptr<tflite::scann_ondevice::core::AsymmetricHashQuerier>
  2557. + quantizer_;
  2558. };
  2559. } // namespace processor
  2560. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_clu_annotator.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_clu_annotator.cc
  2561. index f60a556dbbe1b..802facec374f3 100644
  2562. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_clu_annotator.cc
  2563. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_clu_annotator.cc
  2564. @@ -19,9 +19,9 @@ limitations under the License.
  2565. #include <string>
  2566. #include <utility>
  2567. -#include "absl/status/status.h" // from @com_google_absl
  2568. +#include "absl/status/status.h" // from @com_google_absl
  2569. #include "absl/strings/str_format.h" // from @com_google_absl
  2570. -#include "absl/strings/str_split.h" // from @com_google_absl
  2571. +#include "absl/strings/str_split.h" // from @com_google_absl
  2572. #include "tensorflow_lite_support/cc/port/status_macros.h"
  2573. #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
  2574. #include "tensorflow_lite_support/cc/task/core/task_utils.h"
  2575. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc
  2576. index 52c898dacb9ca..d4481cdd17874 100644
  2577. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc
  2578. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc
  2579. @@ -57,7 +57,8 @@ absl::Status SanityCheckOptions(const BertNLClassifierOptions& options) {
  2580. } // namespace
  2581. absl::Status BertNLClassifier::Preprocess(
  2582. - const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
  2583. + const std::vector<TfLiteTensor*>& input_tensors,
  2584. + const std::string& input) {
  2585. return preprocessor_->Preprocess(input);
  2586. }
  2587. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h
  2588. index 4151025df917b..bcc9c5a533a3f 100644
  2589. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h
  2590. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h
  2591. @@ -22,7 +22,7 @@ limitations under the License.
  2592. #include <string>
  2593. #include <vector>
  2594. -#include "absl/base/macros.h" // from @com_google_absl
  2595. +#include "absl/base/macros.h" // from @com_google_absl
  2596. #include "absl/status/status.h" // from @com_google_absl
  2597. #include "tensorflow/lite/c/common.h"
  2598. #include "tensorflow/lite/core/api/op_resolver.h"
  2599. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc
  2600. index 6b37649d4fbfd..b886e3b362902 100644
  2601. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc
  2602. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc
  2603. @@ -15,7 +15,7 @@ limitations under the License.
  2604. #include "tensorflow_lite_support/cc/task/text/bert_question_answerer.h"
  2605. -#include "absl/status/status.h" // from @com_google_absl
  2606. +#include "absl/status/status.h" // from @com_google_absl
  2607. #include "absl/strings/str_join.h" // from @com_google_absl
  2608. #include "absl/strings/str_split.h" // from @com_google_absl
  2609. #include "tensorflow/lite/core/shims/cc/kernels/register.h"
  2610. @@ -111,7 +111,8 @@ StatusOr<std::unique_ptr<QuestionAnswerer>> BertQuestionAnswerer::CreateFromFd(
  2611. StatusOr<std::unique_ptr<QuestionAnswerer>>
  2612. BertQuestionAnswerer::CreateBertQuestionAnswererFromFile(
  2613. - const std::string& path_to_model, const std::string& path_to_vocab) {
  2614. + const std::string& path_to_model,
  2615. + const std::string& path_to_vocab) {
  2616. std::unique_ptr<BertQuestionAnswerer> api_to_init;
  2617. ASSIGN_OR_RETURN(
  2618. api_to_init,
  2619. @@ -125,8 +126,10 @@ BertQuestionAnswerer::CreateBertQuestionAnswererFromFile(
  2620. StatusOr<std::unique_ptr<QuestionAnswerer>>
  2621. BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer(
  2622. - const char* model_buffer_data, size_t model_buffer_size,
  2623. - const char* vocab_buffer_data, size_t vocab_buffer_size) {
  2624. + const char* model_buffer_data,
  2625. + size_t model_buffer_size,
  2626. + const char* vocab_buffer_data,
  2627. + size_t vocab_buffer_size) {
  2628. std::unique_ptr<BertQuestionAnswerer> api_to_init;
  2629. ASSIGN_OR_RETURN(
  2630. api_to_init,
  2631. @@ -141,7 +144,8 @@ BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer(
  2632. StatusOr<std::unique_ptr<QuestionAnswerer>>
  2633. BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile(
  2634. - const std::string& path_to_model, const std::string& path_to_spmodel) {
  2635. + const std::string& path_to_model,
  2636. + const std::string& path_to_spmodel) {
  2637. std::unique_ptr<BertQuestionAnswerer> api_to_init;
  2638. ASSIGN_OR_RETURN(
  2639. api_to_init,
  2640. @@ -155,8 +159,10 @@ BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile(
  2641. StatusOr<std::unique_ptr<QuestionAnswerer>>
  2642. BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer(
  2643. - const char* model_buffer_data, size_t model_buffer_size,
  2644. - const char* spmodel_buffer_data, size_t spmodel_buffer_size) {
  2645. + const char* model_buffer_data,
  2646. + size_t model_buffer_size,
  2647. + const char* spmodel_buffer_data,
  2648. + size_t spmodel_buffer_size) {
  2649. std::unique_ptr<BertQuestionAnswerer> api_to_init;
  2650. ASSIGN_OR_RETURN(
  2651. api_to_init,
  2652. @@ -170,14 +176,16 @@ BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer(
  2653. }
  2654. std::vector<QaAnswer> BertQuestionAnswerer::Answer(
  2655. - const std::string& context, const std::string& question) {
  2656. + const std::string& context,
  2657. + const std::string& question) {
  2658. // The BertQuestionAnswererer implementation for Preprocess() and
  2659. // Postprocess() never returns errors: just call value().
  2660. return Infer(context, question).value();
  2661. }
  2662. absl::Status BertQuestionAnswerer::Preprocess(
  2663. - const std::vector<TfLiteTensor*>& input_tensors, const std::string& context,
  2664. + const std::vector<TfLiteTensor*>& input_tensors,
  2665. + const std::string& context,
  2666. const std::string& query) {
  2667. auto* input_tensor_metadatas =
  2668. GetMetadataExtractor()->GetInputTensorMetadata();
  2669. @@ -392,7 +400,8 @@ void BertQuestionAnswerer::InitializeBertTokenizer(
  2670. }
  2671. void BertQuestionAnswerer::InitializeBertTokenizerFromBinary(
  2672. - const char* vocab_buffer_data, size_t vocab_buffer_size) {
  2673. + const char* vocab_buffer_data,
  2674. + size_t vocab_buffer_size) {
  2675. tokenizer_ =
  2676. absl::make_unique<BertTokenizer>(vocab_buffer_data, vocab_buffer_size);
  2677. }
  2678. @@ -403,7 +412,8 @@ void BertQuestionAnswerer::InitializeSentencepieceTokenizer(
  2679. }
  2680. void BertQuestionAnswerer::InitializeSentencepieceTokenizerFromBinary(
  2681. - const char* spmodel_buffer_data, size_t spmodel_buffer_size) {
  2682. + const char* spmodel_buffer_data,
  2683. + size_t spmodel_buffer_size) {
  2684. tokenizer_ = absl::make_unique<SentencePieceTokenizer>(spmodel_buffer_data,
  2685. spmodel_buffer_size);
  2686. }
  2687. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h
  2688. index f041cc8e51637..52ec835371386 100644
  2689. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h
  2690. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h
  2691. @@ -16,9 +16,9 @@ limitations under the License.
  2692. #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_QA_BERT_QUESTION_ANSWERER_H_
  2693. #define TENSORFLOW_LITE_SUPPORT_CC_TASK_QA_BERT_QUESTION_ANSWERER_H_
  2694. -#include "absl/base/macros.h" // from @com_google_absl
  2695. +#include "absl/base/macros.h" // from @com_google_absl
  2696. #include "absl/container/flat_hash_map.h" // from @com_google_absl
  2697. -#include "absl/status/status.h" // from @com_google_absl
  2698. +#include "absl/status/status.h" // from @com_google_absl
  2699. #include "tensorflow_lite_support/cc/port/statusor.h"
  2700. #include "tensorflow_lite_support/cc/task/core/base_task_api.h"
  2701. #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
  2702. @@ -136,7 +136,8 @@ class BertQuestionAnswerer : public QuestionAnswerer {
  2703. void InitializeSentencepieceTokenizer(const std::string& path_to_spmodel);
  2704. // Initialize API with a SentencepieceTokenizer from the model buffer.
  2705. void InitializeSentencepieceTokenizerFromBinary(
  2706. - const char* spmodel_buffer_data, size_t spmodel_buffer_size);
  2707. + const char* spmodel_buffer_data,
  2708. + size_t spmodel_buffer_size);
  2709. // Initialize the API with the tokenizer set in the metadata.
  2710. absl::Status InitializeFromMetadata(
  2711. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/bert_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/bert_utils.cc
  2712. index 0164bf48f156e..dc88aad9c2bdf 100644
  2713. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/bert_utils.cc
  2714. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/bert_utils.cc
  2715. @@ -17,8 +17,8 @@ limitations under the License.
  2716. #include <string>
  2717. -#include "absl/status/status.h" // from @com_google_absl
  2718. -#include "absl/strings/ascii.h" // from @com_google_absl
  2719. +#include "absl/status/status.h" // from @com_google_absl
  2720. +#include "absl/strings/ascii.h" // from @com_google_absl
  2721. #include "absl/strings/str_cat.h" // from @com_google_absl
  2722. #include "tensorflow_lite_support/cc/port/status_macros.h"
  2723. #include "tensorflow_lite_support/cc/task/text/clu_lib/constants.h"
  2724. @@ -46,10 +46,13 @@ constexpr int kTurnIdForCurrentUtterance = 0;
  2725. absl::Status BertPreprocessing(
  2726. const tflite::support::text::tokenizer::BertTokenizer* tokenizer,
  2727. const std::vector<absl::string_view>& utterances_in_reverse_order,
  2728. - int max_seq_length, int max_history_turns, std::vector<int>* out_token_ids,
  2729. + int max_seq_length,
  2730. + int max_history_turns,
  2731. + std::vector<int>* out_token_ids,
  2732. std::vector<std::pair<int, int>>* out_token_alignments,
  2733. std::vector<int>* out_token_first_subword_indicators,
  2734. - std::vector<int>* out_segment_id_list, std::vector<int>* out_turn_id_list) {
  2735. + std::vector<int>* out_segment_id_list,
  2736. + std::vector<int>* out_turn_id_list) {
  2737. int cls_id;
  2738. if (!tokenizer->LookupId(kClsToken, &cls_id)) {
  2739. return absl::InternalError(
  2740. @@ -183,7 +186,8 @@ absl::Status BertPreprocessing(
  2741. out_turn_id_list->push_back(turn_id);
  2742. // Break if reaching max_seq_length.
  2743. - if (out_token_ids->size() >= max_seq_length) break;
  2744. + if (out_token_ids->size() >= max_seq_length)
  2745. + break;
  2746. }
  2747. if (out_token_ids->size() != out_token_alignments->size()) {
  2748. return absl::InternalError(absl::StrCat(
  2749. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/bert_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/bert_utils.h
  2750. index 69d13be6ce114..c3b3f6c4caf78 100644
  2751. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/bert_utils.h
  2752. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/bert_utils.h
  2753. @@ -80,10 +80,13 @@ namespace tflite::task::text::clu {
  2754. absl::Status BertPreprocessing(
  2755. const tflite::support::text::tokenizer::BertTokenizer* tokenizer,
  2756. const std::vector<absl::string_view>& utterances_in_reverse_order,
  2757. - int max_seq_length, int max_history_turns, std::vector<int>* out_token_ids,
  2758. + int max_seq_length,
  2759. + int max_history_turns,
  2760. + std::vector<int>* out_token_ids,
  2761. std::vector<std::pair<int, int>>* out_token_alignments,
  2762. std::vector<int>* out_token_first_subword_indicators,
  2763. - std::vector<int>* out_segment_id_list, std::vector<int>* out_turn_id_list);
  2764. + std::vector<int>* out_segment_id_list,
  2765. + std::vector<int>* out_turn_id_list);
  2766. } // namespace tflite::task::text::clu
  2767. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/intent_repr.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/intent_repr.cc
  2768. index b310a0782c69f..037566235cf7c 100644
  2769. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/intent_repr.cc
  2770. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/intent_repr.cc
  2771. @@ -17,9 +17,9 @@ limitations under the License.
  2772. #include <vector>
  2773. -#include "absl/status/status.h" // from @com_google_absl
  2774. -#include "absl/strings/str_cat.h" // from @com_google_absl
  2775. -#include "absl/strings/str_split.h" // from @com_google_absl
  2776. +#include "absl/status/status.h" // from @com_google_absl
  2777. +#include "absl/strings/str_cat.h" // from @com_google_absl
  2778. +#include "absl/strings/str_split.h" // from @com_google_absl
  2779. #include "absl/strings/string_view.h" // from @com_google_absl
  2780. #include "tensorflow_lite_support/cc/task/text/clu_lib/constants.h"
  2781. @@ -28,7 +28,8 @@ namespace tflite::task::text::clu {
  2782. // IntentRepr
  2783. std::string IntentRepr::FullName() const {
  2784. - if (domain_.empty()) return name_;
  2785. + if (domain_.empty())
  2786. + return name_;
  2787. return absl::StrCat(domain_, kNamespaceDelim, name_);
  2788. }
  2789. @@ -40,16 +41,19 @@ absl::StatusOr<IntentRepr> IntentRepr::CreateFromFullName(
  2790. if (splits.size() > 2) {
  2791. return absl::InternalError(absl::StrCat("invalid argument: ", full_name));
  2792. }
  2793. - if (splits.size() == 2) ret.domain_ = splits[0];
  2794. + if (splits.size() == 2)
  2795. + ret.domain_ = splits[0];
  2796. ret.name_ = splits[splits.size() - 1];
  2797. return ret;
  2798. }
  2799. -IntentRepr IntentRepr::Create(absl::string_view name, absl::string_view domain,
  2800. +IntentRepr IntentRepr::Create(absl::string_view name,
  2801. + absl::string_view domain,
  2802. const bool share_across_domains) {
  2803. IntentRepr ret;
  2804. ret.name_ = std::string(name);
  2805. - if (!share_across_domains) ret.domain_ = std::string(domain);
  2806. + if (!share_across_domains)
  2807. + ret.domain_ = std::string(domain);
  2808. return ret;
  2809. }
  2810. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/intent_repr.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/intent_repr.h
  2811. index 9084deb1203b4..e040b04d998ea 100644
  2812. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/intent_repr.h
  2813. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/intent_repr.h
  2814. @@ -18,7 +18,7 @@ limitations under the License.
  2815. #include <string>
  2816. -#include "absl/status/statusor.h" // from @com_google_absl
  2817. +#include "absl/status/statusor.h" // from @com_google_absl
  2818. #include "absl/strings/string_view.h" // from @com_google_absl
  2819. namespace tflite::task::text::clu {
  2820. @@ -30,7 +30,8 @@ class IntentRepr {
  2821. const std::string& Name() const { return name_; }
  2822. std::string FullName() const;
  2823. static absl::StatusOr<IntentRepr> CreateFromFullName(const absl::string_view);
  2824. - static IntentRepr Create(absl::string_view name, absl::string_view domain,
  2825. + static IntentRepr Create(absl::string_view name,
  2826. + absl::string_view domain,
  2827. const bool share_across_domains);
  2828. private:
  2829. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.cc
  2830. index 114a721ee40ef..dbb0dc2a14263 100644
  2831. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.cc
  2832. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.cc
  2833. @@ -20,15 +20,15 @@ limitations under the License.
  2834. #include <memory>
  2835. #include <vector>
  2836. -#include "absl/status/status.h" // from @com_google_absl
  2837. -#include "absl/status/statusor.h" // from @com_google_absl
  2838. -#include "absl/strings/match.h" // from @com_google_absl
  2839. -#include "absl/strings/str_cat.h" // from @com_google_absl
  2840. -#include "absl/strings/str_split.h" // from @com_google_absl
  2841. +#include "absl/status/status.h" // from @com_google_absl
  2842. +#include "absl/status/statusor.h" // from @com_google_absl
  2843. +#include "absl/strings/match.h" // from @com_google_absl
  2844. +#include "absl/strings/str_cat.h" // from @com_google_absl
  2845. +#include "absl/strings/str_split.h" // from @com_google_absl
  2846. #include "absl/strings/string_view.h" // from @com_google_absl
  2847. -#include "absl/strings/strip.h" // from @com_google_absl
  2848. -#include "absl/strings/substitute.h" // from @com_google_absl
  2849. -#include "absl/types/span.h" // from @com_google_absl
  2850. +#include "absl/strings/strip.h" // from @com_google_absl
  2851. +#include "absl/strings/substitute.h" // from @com_google_absl
  2852. +#include "absl/types/span.h" // from @com_google_absl
  2853. #include "tensorflow_lite_support/cc/port/status_macros.h"
  2854. #include "tensorflow_lite_support/cc/task/text/clu_lib/constants.h"
  2855. @@ -39,7 +39,8 @@ using ::absl::StatusOr;
  2856. // SlotRepr
  2857. std::string SlotRepr::FullName() const {
  2858. - if (domain_.empty()) return name_;
  2859. + if (domain_.empty())
  2860. + return name_;
  2861. return absl::StrCat(domain_, kNamespaceDelim, name_);
  2862. }
  2863. @@ -52,14 +53,16 @@ SlotRepr::SplitDomainAndName(const absl::string_view full_name) {
  2864. }
  2865. absl::string_view domain = "";
  2866. absl::string_view name;
  2867. - if (splits.size() == 2) domain = splits[0];
  2868. + if (splits.size() == 2)
  2869. + domain = splits[0];
  2870. name = splits[splits.size() - 1];
  2871. return std::tuple<absl::string_view, absl::string_view>{domain, name};
  2872. }
  2873. StatusOr<SlotRepr> SlotRepr::CreateFromIob(const absl::string_view repr) {
  2874. SlotRepr ret;
  2875. - if (IsO(repr)) return ret;
  2876. + if (IsO(repr))
  2877. + return ret;
  2878. absl::string_view full_name;
  2879. if (absl::StartsWith(repr, kSlotBTagPrefix)) {
  2880. full_name = absl::StripPrefix(repr, kSlotBTagPrefix);
  2881. @@ -76,7 +79,8 @@ StatusOr<SlotRepr> SlotRepr::CreateFromIob(const absl::string_view repr) {
  2882. return ret;
  2883. }
  2884. -SlotRepr SlotRepr::Create(absl::string_view name, absl::string_view domain,
  2885. +SlotRepr SlotRepr::Create(absl::string_view name,
  2886. + absl::string_view domain,
  2887. const bool share_across_domains) {
  2888. SlotRepr ret;
  2889. ret.name_ = std::string(name);
  2890. @@ -94,7 +98,9 @@ bool SlotRepr::IsB(const absl::string_view repr) {
  2891. return absl::StartsWith(repr, kSlotBTagPrefix);
  2892. }
  2893. -bool SlotRepr::IsO(const absl::string_view repr) { return repr == kSlotOTag; }
  2894. +bool SlotRepr::IsO(const absl::string_view repr) {
  2895. + return repr == kSlotOTag;
  2896. +}
  2897. bool SlotRepr::operator==(const SlotRepr& other) const {
  2898. return domain_ == other.domain_ && name_ == other.name_;
  2899. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.h
  2900. index 04ca49b268917..9a5f68a00bdcd 100644
  2901. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.h
  2902. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.h
  2903. @@ -20,9 +20,9 @@ limitations under the License.
  2904. #include <utility>
  2905. #include <vector>
  2906. -#include "absl/status/status.h" // from @com_google_absl
  2907. -#include "absl/status/statusor.h" // from @com_google_absl
  2908. -#include "absl/strings/str_cat.h" // from @com_google_absl
  2909. +#include "absl/status/status.h" // from @com_google_absl
  2910. +#include "absl/status/statusor.h" // from @com_google_absl
  2911. +#include "absl/strings/str_cat.h" // from @com_google_absl
  2912. #include "absl/strings/string_view.h" // from @com_google_absl
  2913. #include "tensorflow_lite_support/cc/task/text/clu_lib/constants.h"
  2914. @@ -68,7 +68,8 @@ class SlotRepr {
  2915. static absl::StatusOr<SlotRepr> CreateFromIob(const absl::string_view);
  2916. // Factory
  2917. - static SlotRepr Create(absl::string_view name, absl::string_view domain = "",
  2918. + static SlotRepr Create(absl::string_view name,
  2919. + absl::string_view domain = "",
  2920. const bool share_across_domains = true);
  2921. // Splits the full_name into domain and slot name.
  2922. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_tagging_output.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_tagging_output.cc
  2923. index 716d29b76a98f..0d5abb443fcc3 100644
  2924. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_tagging_output.cc
  2925. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_tagging_output.cc
  2926. @@ -17,10 +17,10 @@ limitations under the License.
  2927. #include <vector>
  2928. -#include "absl/status/status.h" // from @com_google_absl
  2929. -#include "absl/status/statusor.h" // from @com_google_absl
  2930. +#include "absl/status/status.h" // from @com_google_absl
  2931. +#include "absl/status/statusor.h" // from @com_google_absl
  2932. #include "absl/strings/string_view.h" // from @com_google_absl
  2933. -#include "absl/types/span.h" // from @com_google_absl
  2934. +#include "absl/types/span.h" // from @com_google_absl
  2935. #include "tensorflow_lite_support/cc/port/status_macros.h"
  2936. #include "tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.h"
  2937. @@ -29,7 +29,9 @@ namespace {
  2938. absl::StatusOr<std::vector<SlotMentionStruct>>
  2939. DecodeSlotChunksPredictOnFirstSubword(
  2940. - int cur_turn_start, int cur_turn_end, int seq_len,
  2941. + int cur_turn_start,
  2942. + int cur_turn_end,
  2943. + int seq_len,
  2944. const absl::Span<const absl::string_view> tags_as_span,
  2945. const absl::Span<const float> confidences_as_span,
  2946. const absl::Span<const std::pair<int, int>> token_alignments_as_span,
  2947. @@ -74,10 +76,12 @@ DecodeSlotChunksPredictOnFirstSubword(
  2948. } // namespace
  2949. absl::Status SlotModulePopulateResponse(
  2950. - const std::vector<absl::string_view>& tags, const float* confidences,
  2951. + const std::vector<absl::string_view>& tags,
  2952. + const float* confidences,
  2953. const std::vector<std::pair<int, int>>& token_alignments,
  2954. const std::vector<int>& token_turn_ids,
  2955. - const std::vector<int>& first_subword_indicators, float threshold,
  2956. + const std::vector<int>& first_subword_indicators,
  2957. + float threshold,
  2958. const std::vector<absl::string_view>& reverse_utterance_list_to_encode,
  2959. CluResponse* response) {
  2960. if (token_alignments.size() != token_turn_ids.size()) {
  2961. @@ -104,7 +108,7 @@ absl::Status SlotModulePopulateResponse(
  2962. // Prepare the data and decode slot chunks.
  2963. std::vector<SlotMentionStruct> cur_turn_slot_mentions;
  2964. - // Decode slot chunks based on first subword tokens in the turn.
  2965. + // Decode slot chunks based on first subword tokens in the turn.
  2966. ASSIGN_OR_RETURN(cur_turn_slot_mentions,
  2967. DecodeSlotChunksPredictOnFirstSubword(
  2968. cur_turn_start, cur_turn_end, seq_len, tags_as_span,
  2969. @@ -113,8 +117,10 @@ absl::Status SlotModulePopulateResponse(
  2970. // Populate the response.
  2971. for (const auto& chunk : cur_turn_slot_mentions) {
  2972. - if (chunk.start == -1 || cur_turn_idx != 0) continue;
  2973. - if (chunk.confidence < threshold) continue;
  2974. + if (chunk.start == -1 || cur_turn_idx != 0)
  2975. + continue;
  2976. + if (chunk.confidence < threshold)
  2977. + continue;
  2978. auto slot = response->mutable_noncategorical_slots()->Add();
  2979. slot->set_slot(chunk.repr.Name());
  2980. auto extraction = slot->mutable_extraction();
  2981. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_tagging_output.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_tagging_output.h
  2982. index b8fc64425634e..7d2b9a1a1fd27 100644
  2983. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_tagging_output.h
  2984. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_tagging_output.h
  2985. @@ -41,10 +41,12 @@ namespace tflite::task::text::clu {
  2986. // Outputs:
  2987. // response
  2988. absl::Status SlotModulePopulateResponse(
  2989. - const std::vector<absl::string_view>& tags, const float* confidences,
  2990. + const std::vector<absl::string_view>& tags,
  2991. + const float* confidences,
  2992. const std::vector<std::pair<int, int>>& token_alignments,
  2993. const std::vector<int>& token_turn_ids,
  2994. - const std::vector<int>& first_subword_indicators, float threshold,
  2995. + const std::vector<int>& first_subword_indicators,
  2996. + float threshold,
  2997. const std::vector<absl::string_view>& reverse_utterance_list_to_encode,
  2998. CluResponse* response);
  2999. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.cc
  3000. index c16f5bc02b861..f893f0341c903 100644
  3001. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.cc
  3002. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.cc
  3003. @@ -18,11 +18,11 @@ limitations under the License.
  3004. #include <memory>
  3005. #include <utility>
  3006. -#include "absl/status/status.h" // from @com_google_absl
  3007. -#include "absl/status/statusor.h" // from @com_google_absl
  3008. -#include "absl/strings/str_cat.h" // from @com_google_absl
  3009. -#include "absl/strings/str_join.h" // from @com_google_absl
  3010. -#include "absl/strings/str_split.h" // from @com_google_absl
  3011. +#include "absl/status/status.h" // from @com_google_absl
  3012. +#include "absl/status/statusor.h" // from @com_google_absl
  3013. +#include "absl/strings/str_cat.h" // from @com_google_absl
  3014. +#include "absl/strings/str_join.h" // from @com_google_absl
  3015. +#include "absl/strings/str_split.h" // from @com_google_absl
  3016. #include "absl/strings/string_view.h" // from @com_google_absl
  3017. #include "tensorflow/lite/kernels/kernel_util.h"
  3018. #include "tensorflow/lite/string_util.h"
  3019. @@ -39,10 +39,14 @@ namespace tflite::task::text::clu {
  3020. // tensors by concatenating the current utterance with history turns. It also
  3021. // sets utterance_turn_id_seq for post-processing.
  3022. absl::Status PopulateInputTextTensorForBERT(
  3023. - const CluRequest& request, int token_id_tensor_idx,
  3024. - int token_mask_tensor_idx, int token_type_id_tensor_idx,
  3025. + const CluRequest& request,
  3026. + int token_id_tensor_idx,
  3027. + int token_mask_tensor_idx,
  3028. + int token_type_id_tensor_idx,
  3029. const tflite::support::text::tokenizer::BertTokenizer* tokenizer,
  3030. - size_t max_seq_len, int max_history_turns, tflite::Interpreter* interpreter,
  3031. + size_t max_seq_len,
  3032. + int max_history_turns,
  3033. + tflite::Interpreter* interpreter,
  3034. Artifacts* artifacts) {
  3035. size_t seq_len;
  3036. int64_t* tokens_tensor =
  3037. @@ -139,7 +143,8 @@ absl::Status AbstractModule::Init(tflite::Interpreter* interpreter,
  3038. }
  3039. absl::StatusOr<std::unique_ptr<AbstractModule>> UtteranceSeqModule::Create(
  3040. - tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
  3041. + tflite::Interpreter* interpreter,
  3042. + const TensorIndexMap* tensor_index_map,
  3043. const BertCluAnnotatorOptions* options,
  3044. const tflite::support::text::tokenizer::BertTokenizer* tokenizer) {
  3045. auto out = std::make_unique<UtteranceSeqModule>();
  3046. @@ -187,7 +192,8 @@ AbstractModule::NamesAndConfidencesFromOutput(int names_tensor_idx,
  3047. }
  3048. absl::StatusOr<std::unique_ptr<AbstractModule>> DomainModule::Create(
  3049. - tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
  3050. + tflite::Interpreter* interpreter,
  3051. + const TensorIndexMap* tensor_index_map,
  3052. const BertCluAnnotatorOptions* options) {
  3053. auto out = std::make_unique<DomainModule>();
  3054. out->tensor_index_map_ = tensor_index_map;
  3055. @@ -204,7 +210,8 @@ absl::Status DomainModule::Postprocess(Artifacts* artifacts,
  3056. tensor_index_map_->domain_scores_idx));
  3057. const auto& [names, confidences] = t_output;
  3058. for (int i = 0; i < names.size(); ++i) {
  3059. - if (confidences[i] < domain_threshold_) continue;
  3060. + if (confidences[i] < domain_threshold_)
  3061. + continue;
  3062. auto domain = response->add_domains();
  3063. // Conversion to string is needed due to portable_proto generated code
  3064. const std::string names_i(names[i]);
  3065. @@ -215,7 +222,8 @@ absl::Status DomainModule::Postprocess(Artifacts* artifacts,
  3066. }
  3067. absl::StatusOr<std::unique_ptr<AbstractModule>> IntentModule::Create(
  3068. - tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
  3069. + tflite::Interpreter* interpreter,
  3070. + const TensorIndexMap* tensor_index_map,
  3071. const BertCluAnnotatorOptions* options) {
  3072. auto out = std::make_unique<IntentModule>();
  3073. out->tensor_index_map_ = tensor_index_map;
  3074. @@ -239,7 +247,8 @@ absl::Status IntentModule::Postprocess(Artifacts* artifacts,
  3075. std::vector<absl::string_view> parts = absl::StrSplit(name.Name(), '=');
  3076. if (parts.size() == 2) {
  3077. // The name is like 'xxx=yyy'. It's a categorical slot.
  3078. - if (confidences[i] < categorical_slot_threshold_) continue;
  3079. + if (confidences[i] < categorical_slot_threshold_)
  3080. + continue;
  3081. auto new_categorical_slot = response->mutable_categorical_slots()->Add();
  3082. const auto slot = std::string(parts[0]);
  3083. @@ -251,7 +260,8 @@ absl::Status IntentModule::Postprocess(Artifacts* artifacts,
  3084. new_categorical_slot_prediction->set_score(confidences[i]);
  3085. } else {
  3086. // It's an intent.
  3087. - if (confidences[i] < intent_threshold_) continue;
  3088. + if (confidences[i] < intent_threshold_)
  3089. + continue;
  3090. auto new_intent = response->mutable_intents()->Add();
  3091. new_intent->set_display_name(name.Name());
  3092. new_intent->set_score(confidences[i]);
  3093. @@ -261,7 +271,8 @@ absl::Status IntentModule::Postprocess(Artifacts* artifacts,
  3094. }
  3095. absl::StatusOr<std::unique_ptr<AbstractModule>> SlotModule::Create(
  3096. - tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
  3097. + tflite::Interpreter* interpreter,
  3098. + const TensorIndexMap* tensor_index_map,
  3099. const BertCluAnnotatorOptions* options) {
  3100. auto out = std::make_unique<SlotModule>();
  3101. out->tensor_index_map_ = tensor_index_map;
  3102. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.h
  3103. index 5a9f183b8ca4e..eecd65fc495bf 100644
  3104. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.h
  3105. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.h
  3106. @@ -16,7 +16,7 @@ limitations under the License.
  3107. #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_CLU_LIB_TFLITE_MODULES_H_
  3108. #define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_CLU_LIB_TFLITE_MODULES_H_
  3109. -#include "absl/status/statusor.h" // from @com_google_absl
  3110. +#include "absl/status/statusor.h" // from @com_google_absl
  3111. #include "absl/strings/string_view.h" // from @com_google_absl
  3112. #include "tensorflow/lite/interpreter.h"
  3113. #include "tensorflow_lite_support/cc/task/text/proto/bert_clu_annotator_options_proto_inc.h"
  3114. @@ -85,7 +85,8 @@ class AbstractModule {
  3115. // output tensors.
  3116. // The tensors are assumed to be of shape [1, max_seq_len]
  3117. absl::StatusOr<NamesAndConfidences> NamesAndConfidencesFromOutput(
  3118. - int names_tensor_idx, int scores_tensor_idx) const;
  3119. + int names_tensor_idx,
  3120. + int scores_tensor_idx) const;
  3121. // TFLite interpreter
  3122. Interpreter* interpreter_ = nullptr;
  3123. @@ -98,7 +99,8 @@ class AbstractModule {
  3124. class UtteranceSeqModule : public AbstractModule {
  3125. public:
  3126. static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
  3127. - Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
  3128. + Interpreter* interpreter,
  3129. + const TensorIndexMap* tensor_index_map,
  3130. const BertCluAnnotatorOptions* options,
  3131. const tflite::support::text::tokenizer::BertTokenizer* tokenizer);
  3132. @@ -116,7 +118,8 @@ class UtteranceSeqModule : public AbstractModule {
  3133. class DomainModule : public AbstractModule {
  3134. public:
  3135. static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
  3136. - Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
  3137. + Interpreter* interpreter,
  3138. + const TensorIndexMap* tensor_index_map,
  3139. const BertCluAnnotatorOptions* options);
  3140. absl::Status Postprocess(Artifacts* artifacts,
  3141. @@ -130,7 +133,8 @@ class DomainModule : public AbstractModule {
  3142. class IntentModule : public AbstractModule {
  3143. public:
  3144. static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
  3145. - Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
  3146. + Interpreter* interpreter,
  3147. + const TensorIndexMap* tensor_index_map,
  3148. const BertCluAnnotatorOptions* options);
  3149. absl::Status Postprocess(Artifacts* artifacts,
  3150. @@ -145,7 +149,8 @@ class IntentModule : public AbstractModule {
  3151. class SlotModule : public AbstractModule {
  3152. public:
  3153. static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
  3154. - Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
  3155. + Interpreter* interpreter,
  3156. + const TensorIndexMap* tensor_index_map,
  3157. const BertCluAnnotatorOptions* options);
  3158. absl::Status Postprocess(Artifacts* artifacts,
  3159. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_test_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_test_utils.cc
  3160. index 543958ce93994..30d2bd7513909 100644
  3161. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_test_utils.cc
  3162. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_test_utils.cc
  3163. @@ -24,7 +24,8 @@ namespace tflite::task::text::clu {
  3164. template <>
  3165. void PopulateTfLiteTensorValue<std::string>(
  3166. - const std::initializer_list<std::string> values, TfLiteTensor* tensor) {
  3167. + const std::initializer_list<std::string> values,
  3168. + TfLiteTensor* tensor) {
  3169. tflite::DynamicBuffer buf;
  3170. for (const std::string& s : values) {
  3171. buf.AddString(s.data(), s.length());
  3172. @@ -38,13 +39,18 @@ size_t NumTotalFromShape(const std::initializer_list<int>& shape) {
  3173. num_total = 1;
  3174. else
  3175. num_total = 0;
  3176. - for (const int dim : shape) num_total *= dim;
  3177. + for (const int dim : shape)
  3178. + num_total *= dim;
  3179. return num_total;
  3180. }
  3181. -TfLiteTensor* UniqueTfLiteTensor::get() { return tensor_; }
  3182. +TfLiteTensor* UniqueTfLiteTensor::get() {
  3183. + return tensor_;
  3184. +}
  3185. -UniqueTfLiteTensor::~UniqueTfLiteTensor() { TfLiteTensorFree(tensor_); }
  3186. +UniqueTfLiteTensor::~UniqueTfLiteTensor() {
  3187. + TfLiteTensorFree(tensor_);
  3188. +}
  3189. template <>
  3190. TfLiteType TypeToTfLiteType<std::string>() {
  3191. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_test_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_test_utils.h
  3192. index 3a393c5223369..f19d2366fc092 100644
  3193. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_test_utils.h
  3194. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_test_utils.h
  3195. @@ -64,7 +64,8 @@ size_t NumTotalFromShape(const std::initializer_list<int>& shape);
  3196. template <>
  3197. void PopulateTfLiteTensorValue<std::string>(
  3198. - const std::initializer_list<std::string> values, TfLiteTensor* tensor);
  3199. + const std::initializer_list<std::string> values,
  3200. + TfLiteTensor* tensor);
  3201. template <typename T>
  3202. TfLiteType TypeToTfLiteType() {
  3203. @@ -84,7 +85,8 @@ void ReallocDynamicTensor(const std::initializer_list<int> shape,
  3204. TfLiteIntArray* shape_arr = TfLiteIntArrayCreate(shape.size());
  3205. int i = 0;
  3206. const size_t num_total = NumTotalFromShape(shape);
  3207. - for (const int dim : shape) shape_arr->data[i++] = dim;
  3208. + for (const int dim : shape)
  3209. + shape_arr->data[i++] = dim;
  3210. tensor->dims = shape_arr;
  3211. if (tensor->type != kTfLiteString) {
  3212. TfLiteTensorRealloc(num_total * sizeof(T), tensor);
  3213. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
  3214. index 376ff58ec0b52..5a2966a70e1a2 100644
  3215. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
  3216. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
  3217. @@ -22,10 +22,10 @@ limitations under the License.
  3218. #include <vector>
  3219. #include "absl/algorithm/container.h" // from @com_google_absl
  3220. -#include "absl/status/status.h" // from @com_google_absl
  3221. -#include "absl/strings/str_cat.h" // from @com_google_absl
  3222. +#include "absl/status/status.h" // from @com_google_absl
  3223. +#include "absl/strings/str_cat.h" // from @com_google_absl
  3224. #include "absl/strings/string_view.h" // from @com_google_absl
  3225. -#include "flatbuffers/flatbuffers.h" // from @flatbuffers
  3226. +#include "flatbuffers/flatbuffers.h" // from @flatbuffers
  3227. #include "tensorflow/lite/c/common.h"
  3228. #include "tensorflow/lite/core/api/op_resolver.h"
  3229. #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
  3230. @@ -127,7 +127,8 @@ StatusOr<std::vector<Category>> NLClassifier::ClassifyText(
  3231. }
  3232. absl::Status NLClassifier::Preprocess(
  3233. - const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
  3234. + const std::vector<TfLiteTensor*>& input_tensors,
  3235. + const std::string& input) {
  3236. return preprocessor_->Preprocess(input);
  3237. }
  3238. @@ -307,7 +308,8 @@ StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromOptions(
  3239. StatusOr<std::unique_ptr<NLClassifier>>
  3240. NLClassifier::CreateFromBufferAndOptions(
  3241. - const char* model_buffer_data, size_t model_buffer_size,
  3242. + const char* model_buffer_data,
  3243. + size_t model_buffer_size,
  3244. const NLClassifierOptions& options,
  3245. std::unique_ptr<tflite::OpResolver> resolver) {
  3246. std::unique_ptr<NLClassifier> nl_classifier;
  3247. @@ -320,7 +322,8 @@ NLClassifier::CreateFromBufferAndOptions(
  3248. }
  3249. StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFileAndOptions(
  3250. - const std::string& path_to_model, const NLClassifierOptions& options,
  3251. + const std::string& path_to_model,
  3252. + const NLClassifierOptions& options,
  3253. std::unique_ptr<tflite::OpResolver> resolver) {
  3254. std::unique_ptr<NLClassifier> nl_classifier;
  3255. ASSIGN_OR_RETURN(nl_classifier,
  3256. @@ -331,7 +334,8 @@ StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFileAndOptions(
  3257. }
  3258. StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFdAndOptions(
  3259. - int fd, const NLClassifierOptions& options,
  3260. + int fd,
  3261. + const NLClassifierOptions& options,
  3262. std::unique_ptr<tflite::OpResolver> resolver) {
  3263. std::unique_ptr<NLClassifier> nl_classifier;
  3264. ASSIGN_OR_RETURN(nl_classifier,
  3265. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
  3266. index b7af66044b129..68ddc4b5312b7 100644
  3267. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
  3268. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
  3269. @@ -23,8 +23,8 @@ limitations under the License.
  3270. #include <string>
  3271. #include <vector>
  3272. -#include "absl/base/macros.h" // from @com_google_absl
  3273. -#include "absl/status/status.h" // from @com_google_absl
  3274. +#include "absl/base/macros.h" // from @com_google_absl
  3275. +#include "absl/status/status.h" // from @com_google_absl
  3276. #include "flatbuffers/flatbuffers.h" // from @flatbuffers
  3277. #include "tensorflow/lite/c/common.h"
  3278. #include "tensorflow/lite/core/api/op_resolver.h"
  3279. @@ -109,7 +109,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
  3280. ABSL_DEPRECATED("Prefer using `CreateFromOptions`")
  3281. static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
  3282. CreateFromBufferAndOptions(
  3283. - const char* model_buffer_data, size_t model_buffer_size,
  3284. + const char* model_buffer_data,
  3285. + size_t model_buffer_size,
  3286. const NLClassifierOptions& options = {},
  3287. std::unique_ptr<tflite::OpResolver> resolver =
  3288. absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
  3289. @@ -118,7 +119,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
  3290. ABSL_DEPRECATED("Prefer using `CreateFromOptions`")
  3291. static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
  3292. CreateFromFileAndOptions(
  3293. - const std::string& path_to_model, const NLClassifierOptions& options = {},
  3294. + const std::string& path_to_model,
  3295. + const NLClassifierOptions& options = {},
  3296. std::unique_ptr<tflite::OpResolver> resolver =
  3297. absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
  3298. @@ -126,7 +128,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
  3299. ABSL_DEPRECATED("Prefer using `CreateFromOptions`")
  3300. static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
  3301. CreateFromFdAndOptions(
  3302. - int fd, const NLClassifierOptions& options = {},
  3303. + int fd,
  3304. + const NLClassifierOptions& options = {},
  3305. std::unique_ptr<tflite::OpResolver> resolver =
  3306. absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
  3307. @@ -182,7 +185,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
  3308. const std::vector<TensorType*>& tensors,
  3309. const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
  3310. metadata_array,
  3311. - const std::string& name, int index) {
  3312. + const std::string& name,
  3313. + int index) {
  3314. int tensor_index = FindTensorIndex(tensors, metadata_array, name, index);
  3315. return tensor_index >= 0 && tensor_index < tensors.size()
  3316. ? tensors[tensor_index]
  3317. @@ -197,7 +201,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
  3318. const std::vector<TensorType*>& tensors,
  3319. const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
  3320. metadata_array,
  3321. - const std::string& name, int default_index) {
  3322. + const std::string& name,
  3323. + int default_index) {
  3324. if (metadata_array != nullptr && metadata_array->size() == tensors.size()) {
  3325. for (size_t i = 0; i < metadata_array->size(); i++) {
  3326. if (strcmp(name.data(), metadata_array->Get(i)->name()->c_str()) == 0) {
  3327. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/text_searcher_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/text_searcher_options.proto
  3328. index ebce50cbe5491..ed4c2db81dd01 100644
  3329. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/text_searcher_options.proto
  3330. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/text_searcher_options.proto
  3331. @@ -21,7 +21,6 @@ import "tensorflow_lite_support/cc/task/core/proto/base_options.proto";
  3332. import "tensorflow_lite_support/cc/task/processor/proto/embedding_options.proto";
  3333. import "tensorflow_lite_support/cc/task/processor/proto/search_options.proto";
  3334. -
  3335. // Options for setting up an TextSearcher.
  3336. // Next Id: 4.
  3337. message TextSearcherOptions {
  3338. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h
  3339. index 4cde4329a716b..df21662a40e3a 100644
  3340. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h
  3341. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h
  3342. @@ -45,9 +45,9 @@ struct QaAnswer {
  3343. };
  3344. // Interface for an Question-Answer API.
  3345. -class QuestionAnswerer
  3346. - : public core::BaseTaskApi<std::vector<QaAnswer>, const std::string&,
  3347. - const std::string&> {
  3348. +class QuestionAnswerer : public core::BaseTaskApi<std::vector<QaAnswer>,
  3349. + const std::string&,
  3350. + const std::string&> {
  3351. public:
  3352. explicit QuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine)
  3353. : BaseTaskApi(std::move(engine)) {}
  3354. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.cc
  3355. index 7363540797cf2..f7412224cae66 100644
  3356. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.cc
  3357. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.cc
  3358. @@ -58,7 +58,8 @@ absl::Status SanityCheckOptions(const TextEmbedderOptions& options) {
  3359. /* static */
  3360. tflite::support::StatusOr<double> TextEmbedder::CosineSimilarity(
  3361. - const FeatureVector& u, const FeatureVector& v) {
  3362. + const FeatureVector& u,
  3363. + const FeatureVector& v) {
  3364. return processor::EmbeddingPostprocessor::CosineSimilarity(u, v);
  3365. }
  3366. @@ -170,7 +171,8 @@ tflite::support::StatusOr<EmbeddingResult> TextEmbedder::Embed(
  3367. }
  3368. absl::Status TextEmbedder::Preprocess(
  3369. - const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
  3370. + const std::vector<TfLiteTensor*>& input_tensors,
  3371. + const std::string& input) {
  3372. return preprocessor_->Preprocess(input);
  3373. }
  3374. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.h
  3375. index 3d20d558ca9a0..75597bc040468 100644
  3376. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.h
  3377. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.h
  3378. @@ -84,7 +84,8 @@ class TextEmbedder
  3379. //
  3380. // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
  3381. static tflite::support::StatusOr<double> CosineSimilarity(
  3382. - const processor::FeatureVector& u, const processor::FeatureVector& v);
  3383. + const processor::FeatureVector& u,
  3384. + const processor::FeatureVector& v);
  3385. protected:
  3386. // The options used to build this TextEmbedder.
  3387. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_searcher.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_searcher.h
  3388. index f9f680847ac5b..ca90bb6c0d141 100644
  3389. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_searcher.h
  3390. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_searcher.h
  3391. @@ -19,8 +19,8 @@ limitations under the License.
  3392. #include <memory>
  3393. #include <vector>
  3394. -#include "absl/memory/memory.h" // from @com_google_absl
  3395. -#include "absl/status/status.h" // from @com_google_absl
  3396. +#include "absl/memory/memory.h" // from @com_google_absl
  3397. +#include "absl/status/status.h" // from @com_google_absl
  3398. #include "absl/strings/string_view.h" // from @com_google_absl
  3399. #include "tensorflow/lite/c/common.h"
  3400. #include "tensorflow/lite/core/api/op_resolver.h"
  3401. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc
  3402. index 52b0041039acf..ba6af609c776b 100644
  3403. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc
  3404. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc
  3405. @@ -22,7 +22,7 @@ limitations under the License.
  3406. #include <vector>
  3407. #include "absl/container/flat_hash_map.h" // from @com_google_absl
  3408. -#include "absl/status/status.h" // from @com_google_absl
  3409. +#include "absl/status/status.h" // from @com_google_absl
  3410. #include "tensorflow_lite_support/cc/port/statusor.h"
  3411. #include "tensorflow_lite_support/cc/task/core/base_task_api.h"
  3412. #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
  3413. @@ -169,7 +169,8 @@ StatusOr<FeatureVector> UniversalSentenceEncoderQA::EncodeQuery(
  3414. }
  3415. StatusOr<FeatureVector> UniversalSentenceEncoderQA::EncodeResponse(
  3416. - absl::string_view response_text, absl::string_view response_context) {
  3417. + absl::string_view response_text,
  3418. + absl::string_view response_context) {
  3419. if (response_text.empty() && response_context.empty()) {
  3420. return Status(
  3421. StatusCode::kInvalidArgument,
  3422. @@ -190,7 +191,8 @@ StatusOr<float> UniversalSentenceEncoderQA::Similarity(const FeatureVector& a,
  3423. }
  3424. std::vector<size_t> UniversalSentenceEncoderQA::Top(
  3425. - const RetrievalOutput& output, size_t k) {
  3426. + const RetrievalOutput& output,
  3427. + size_t k) {
  3428. // Ensure k in [0, total_size).
  3429. // If k == 0, it means that all outputs are ranked.
  3430. if (k == 0) {
  3431. @@ -214,7 +216,8 @@ std::vector<size_t> UniversalSentenceEncoderQA::Top(
  3432. }
  3433. Status UniversalSentenceEncoderQA::Preprocess(
  3434. - const std::vector<TfLiteTensor*>& input_tensors, const QAInput& input) {
  3435. + const std::vector<TfLiteTensor*>& input_tensors,
  3436. + const QAInput& input) {
  3437. RETURN_IF_ERROR(
  3438. PopulateTensor(input.query_text, input_tensors[input_indices_[0]]));
  3439. RETURN_IF_ERROR(
  3440. @@ -235,7 +238,8 @@ StatusOr<QAOutput> UniversalSentenceEncoderQA::Postprocess(
  3441. }
  3442. internal::QAOutput UniversalSentenceEncoderQA::Run(
  3443. - absl::string_view query_text, absl::string_view response_text,
  3444. + absl::string_view query_text,
  3445. + absl::string_view response_text,
  3446. absl::string_view response_context) {
  3447. QAInput input;
  3448. input.query_text = std::string(query_text);
  3449. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h
  3450. index 3e83c7132c4e7..9b4a58676209c 100644
  3451. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h
  3452. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h
  3453. @@ -20,8 +20,8 @@ limitations under the License.
  3454. #include <vector>
  3455. #include "absl/container/flat_hash_map.h" // from @com_google_absl
  3456. -#include "absl/status/status.h" // from @com_google_absl
  3457. -#include "absl/strings/str_format.h" // from @com_google_absl
  3458. +#include "absl/status/status.h" // from @com_google_absl
  3459. +#include "absl/strings/str_format.h" // from @com_google_absl
  3460. #include "tensorflow_lite_support/cc/port/statusor.h"
  3461. #include "tensorflow_lite_support/cc/task/core/base_task_api.h"
  3462. #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
  3463. @@ -88,7 +88,8 @@ class UniversalSentenceEncoderQA
  3464. // Encodes response from the text and/or context.
  3465. // Returns an error, if both text and context are empty.
  3466. tflite::support::StatusOr<FeatureVector> EncodeResponse(
  3467. - absl::string_view response_text, absl::string_view response_context);
  3468. + absl::string_view response_text,
  3469. + absl::string_view response_context);
  3470. // Calculates similarity between two encoded vectors (require same size).
  3471. static tflite::support::StatusOr<float> Similarity(const FeatureVector& a,
  3472. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/utils/bert_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/utils/bert_utils.cc
  3473. index 1c0a5b01b7789..04bfc2e4f95d7 100644
  3474. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/utils/bert_utils.cc
  3475. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/utils/bert_utils.cc
  3476. @@ -17,7 +17,7 @@ limitations under the License.
  3477. #include <algorithm>
  3478. -#include "absl/status/status.h" // from @com_google_absl
  3479. +#include "absl/status/status.h" // from @com_google_absl
  3480. #include "absl/strings/str_format.h" // from @com_google_absl
  3481. #include "tensorflow_lite_support/cc/common.h"
  3482. #include "tensorflow_lite_support/cc/task/core/task_utils.h"
  3483. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h
  3484. index 76a03671b54af..d3557fc508c61 100644
  3485. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h
  3486. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h
  3487. @@ -23,7 +23,7 @@ limitations under the License.
  3488. #include "absl/memory/memory.h" // from @com_google_absl
  3489. #include "absl/status/status.h" // from @com_google_absl
  3490. -#include "absl/time/clock.h" // from @com_google_absl
  3491. +#include "absl/time/clock.h" // from @com_google_absl
  3492. #include "tensorflow/lite/c/common.h"
  3493. #include "tensorflow_lite_support/cc/common.h"
  3494. #include "tensorflow_lite_support/cc/port/integral_types.h"
  3495. @@ -45,11 +45,12 @@ namespace vision {
  3496. // Base class providing common logic for vision models.
  3497. template <class OutputType>
  3498. class BaseVisionTaskApi
  3499. - : public tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
  3500. - const BoundingBox&> {
  3501. + : public tflite::task::core::
  3502. + BaseTaskApi<OutputType, const FrameBuffer&, const BoundingBox&> {
  3503. public:
  3504. explicit BaseVisionTaskApi(std::unique_ptr<core::TfLiteEngine> engine)
  3505. - : tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
  3506. + : tflite::task::core::BaseTaskApi<OutputType,
  3507. + const FrameBuffer&,
  3508. const BoundingBox&>(std::move(engine)) {
  3509. }
  3510. // BaseVisionTaskApi is neither copyable nor movable.
  3511. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h
  3512. index 47db0d121d43b..2e1aa6d652967 100644
  3513. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h
  3514. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h
  3515. @@ -18,7 +18,7 @@ limitations under the License.
  3516. #include <string>
  3517. #include <vector>
  3518. -#include "absl/memory/memory.h" // from @com_google_absl
  3519. +#include "absl/memory/memory.h" // from @com_google_absl
  3520. #include "absl/strings/string_view.h" // from @com_google_absl
  3521. #include "tensorflow_lite_support/cc/port/statusor.h"
  3522. #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
  3523. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
  3524. index 1668447393e9e..2936f5acbb921 100644
  3525. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
  3526. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
  3527. @@ -22,12 +22,12 @@ limitations under the License.
  3528. #include <utility>
  3529. #include <vector>
  3530. -#include "absl/memory/memory.h" // from @com_google_absl
  3531. -#include "absl/status/status.h" // from @com_google_absl
  3532. +#include "absl/memory/memory.h" // from @com_google_absl
  3533. +#include "absl/status/status.h" // from @com_google_absl
  3534. #include "absl/strings/str_cat.h" // from @com_google_absl
  3535. -#include "absl/time/clock.h" // from @com_google_absl
  3536. -#include "absl/time/time.h" // from @com_google_absl
  3537. -#include "absl/types/optional.h" // from @com_google_absl
  3538. +#include "absl/time/clock.h" // from @com_google_absl
  3539. +#include "absl/time/time.h" // from @com_google_absl
  3540. +#include "absl/types/optional.h" // from @com_google_absl
  3541. #include "tensorflow_lite_support/cc/port/integral_types.h"
  3542. #include "tensorflow_lite_support/cc/port/statusor.h"
  3543. @@ -74,7 +74,16 @@ namespace vision {
  3544. class FrameBuffer {
  3545. public:
  3546. // Colorspace formats.
  3547. - enum class Format { kRGBA, kRGB, kNV12, kNV21, kYV12, kYV21, kGRAY, kUNKNOWN};
  3548. + enum class Format {
  3549. + kRGBA,
  3550. + kRGB,
  3551. + kNV12,
  3552. + kNV21,
  3553. + kYV12,
  3554. + kYV21,
  3555. + kGRAY,
  3556. + kUNKNOWN
  3557. + };
  3558. // Stride information.
  3559. struct Stride {
  3560. @@ -166,7 +175,8 @@ class FrameBuffer {
  3561. // buffers. In a streaming use case (e.g continuous camera stream), the
  3562. // timestamp can be used as an ID to identify a frame.
  3563. static std::unique_ptr<FrameBuffer> Create(const std::vector<Plane>& planes,
  3564. - Dimension dimension, Format format,
  3565. + Dimension dimension,
  3566. + Format format,
  3567. Orientation orientation,
  3568. absl::Time timestamp) {
  3569. return absl::make_unique<FrameBuffer>(planes, dimension, format,
  3570. @@ -177,7 +187,8 @@ class FrameBuffer {
  3571. // backing buffers. In a streaming use case (e.g continuous camera stream),
  3572. // the timestamp can be used as an ID to identify a frame.
  3573. static std::unique_ptr<FrameBuffer> Create(std::vector<Plane>&& planes,
  3574. - Dimension dimension, Format format,
  3575. + Dimension dimension,
  3576. + Format format,
  3577. Orientation orientation,
  3578. absl::Time timestamp) {
  3579. return absl::make_unique<FrameBuffer>(std::move(planes), dimension, format,
  3580. @@ -189,7 +200,8 @@ class FrameBuffer {
  3581. // more suitable for processing use case that does not need to re-identify
  3582. // this buffer.
  3583. static std::unique_ptr<FrameBuffer> Create(const std::vector<Plane>& planes,
  3584. - Dimension dimension, Format format,
  3585. + Dimension dimension,
  3586. + Format format,
  3587. Orientation orientation) {
  3588. return absl::make_unique<FrameBuffer>(planes, dimension, format,
  3589. orientation, absl::Now());
  3590. @@ -200,7 +212,8 @@ class FrameBuffer {
  3591. // method is more suitable for processing use case that does not need to
  3592. // re-identify this buffer.
  3593. static std::unique_ptr<FrameBuffer> Create(std::vector<Plane>&& planes,
  3594. - Dimension dimension, Format format,
  3595. + Dimension dimension,
  3596. + Format format,
  3597. Orientation orientation) {
  3598. return absl::make_unique<FrameBuffer>(std::move(planes), dimension, format,
  3599. orientation, absl::Now());
  3600. @@ -217,8 +230,11 @@ class FrameBuffer {
  3601. // The FrameBuffer does not take ownership of the backing buffer. The backing
  3602. // buffer is read-only and the caller is responsible for maintaining the
  3603. // backing buffer lifecycle for the lifetime of FrameBuffer.
  3604. - FrameBuffer(const std::vector<Plane>& planes, Dimension dimension,
  3605. - Format format, Orientation orientation, absl::Time timestamp)
  3606. + FrameBuffer(const std::vector<Plane>& planes,
  3607. + Dimension dimension,
  3608. + Format format,
  3609. + Orientation orientation,
  3610. + absl::Time timestamp)
  3611. : planes_(planes),
  3612. dimension_(dimension),
  3613. format_(format),
  3614. @@ -230,8 +246,11 @@ class FrameBuffer {
  3615. // The FrameBuffer does not take ownership of the backing buffer. The backing
  3616. // buffer is read-only and the caller is responsible for maintaining the
  3617. // backing buffer lifecycle for the lifetime of FrameBuffer.
  3618. - FrameBuffer(std::vector<Plane>&& planes, Dimension dimension, Format format,
  3619. - Orientation orientation, absl::Time timestamp)
  3620. + FrameBuffer(std::vector<Plane>&& planes,
  3621. + Dimension dimension,
  3622. + Format format,
  3623. + Orientation orientation,
  3624. + absl::Time timestamp)
  3625. : planes_(std::move(planes)),
  3626. dimension_(dimension),
  3627. format_(format),
  3628. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc
  3629. index 9c82b63a10359..67fe07534b52a 100644
  3630. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc
  3631. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc
  3632. @@ -16,7 +16,7 @@ limitations under the License.
  3633. #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
  3634. #include "absl/strings/str_format.h" // from @com_google_absl
  3635. -#include "absl/strings/str_split.h" // from @com_google_absl
  3636. +#include "absl/strings/str_split.h" // from @com_google_absl
  3637. #include "tensorflow_lite_support/cc/common.h"
  3638. namespace tflite {
  3639. @@ -29,7 +29,8 @@ using ::tflite::support::StatusOr;
  3640. using ::tflite::support::TfLiteSupportStatus;
  3641. StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
  3642. - absl::string_view labels_file, absl::string_view display_names_file) {
  3643. + absl::string_view labels_file,
  3644. + absl::string_view display_names_file) {
  3645. if (labels_file.empty()) {
  3646. return CreateStatusWithPayload(StatusCode::kInvalidArgument,
  3647. "Expected non-empty labels file.",
  3648. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h
  3649. index 0fb66f2639806..20c316ba4a992 100644
  3650. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h
  3651. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h
  3652. @@ -20,8 +20,8 @@ limitations under the License.
  3653. #include "absl/container/flat_hash_map.h" // from @com_google_absl
  3654. #include "absl/container/flat_hash_set.h" // from @com_google_absl
  3655. -#include "absl/status/status.h" // from @com_google_absl
  3656. -#include "absl/strings/string_view.h" // from @com_google_absl
  3657. +#include "absl/status/status.h" // from @com_google_absl
  3658. +#include "absl/strings/string_view.h" // from @com_google_absl
  3659. #include "tensorflow_lite_support/cc/port/statusor.h"
  3660. namespace tflite {
  3661. @@ -49,7 +49,8 @@ struct LabelMapItem {
  3662. // Returns an error e.g. if there's a mismatch between the number of labels and
  3663. // display names.
  3664. tflite::support::StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
  3665. - absl::string_view labels_file, absl::string_view display_names_file);
  3666. + absl::string_view labels_file,
  3667. + absl::string_view display_names_file);
  3668. // A class that represents a hierarchy of labels as specified in a label map.
  3669. //
  3670. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc
  3671. index aa1e7707dd99b..36ab3c3ca1903 100644
  3672. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc
  3673. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc
  3674. @@ -16,9 +16,9 @@ limitations under the License.
  3675. #include "tensorflow_lite_support/cc/task/vision/image_classifier.h"
  3676. #include "absl/algorithm/container.h" // from @com_google_absl
  3677. -#include "absl/strings/str_format.h" // from @com_google_absl
  3678. +#include "absl/strings/str_format.h" // from @com_google_absl
  3679. #include "absl/strings/string_view.h" // from @com_google_absl
  3680. -#include "flatbuffers/flatbuffers.h" // from @flatbuffers
  3681. +#include "flatbuffers/flatbuffers.h" // from @flatbuffers
  3682. #include "tensorflow_lite_support/cc/common.h"
  3683. #include "tensorflow_lite_support/cc/port/integral_types.h"
  3684. #include "tensorflow_lite_support/cc/port/status_macros.h"
  3685. @@ -146,7 +146,9 @@ absl::Status ImageClassifier::PreInit() {
  3686. return absl::OkStatus();
  3687. }
  3688. -absl::Status ImageClassifier::PostInit() { return InitScoreCalibrations(); }
  3689. +absl::Status ImageClassifier::PostInit() {
  3690. + return InitScoreCalibrations();
  3691. +}
  3692. absl::Status ImageClassifier::CheckAndSetOutputs() {
  3693. num_outputs_ = TfLiteEngine::OutputCount(GetTfLiteEngine()->interpreter());
  3694. @@ -380,13 +382,15 @@ StatusOr<ClassificationResult> ImageClassifier::Classify(
  3695. }
  3696. StatusOr<ClassificationResult> ImageClassifier::Classify(
  3697. - const FrameBuffer& frame_buffer, const BoundingBox& roi) {
  3698. + const FrameBuffer& frame_buffer,
  3699. + const BoundingBox& roi) {
  3700. return InferWithFallback(frame_buffer, roi);
  3701. }
  3702. StatusOr<ClassificationResult> ImageClassifier::Postprocess(
  3703. const std::vector<const TfLiteTensor*>& output_tensors,
  3704. - const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) {
  3705. + const FrameBuffer& /*frame_buffer*/,
  3706. + const BoundingBox& /*roi*/) {
  3707. if (output_tensors.size() != num_outputs_) {
  3708. return CreateStatusWithPayload(
  3709. StatusCode::kInternal,
  3710. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h
  3711. index b2f595715e9da..eb0c13ec55c5b 100644
  3712. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h
  3713. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h
  3714. @@ -20,7 +20,7 @@ limitations under the License.
  3715. #include <vector>
  3716. #include "absl/container/flat_hash_set.h" // from @com_google_absl
  3717. -#include "absl/status/status.h" // from @com_google_absl
  3718. +#include "absl/status/status.h" // from @com_google_absl
  3719. #include "tensorflow/lite/c/common.h"
  3720. #include "tensorflow/lite/core/api/op_resolver.h"
  3721. #include "tensorflow/lite/core/shims/cc/kernels/register.h"
  3722. @@ -109,7 +109,8 @@ class ImageClassifier : public BaseVisionTaskApi<ClassificationResult> {
  3723. // region of interest is not clamped, so this method will return a non-ok
  3724. // status if the region is out of these bounds.
  3725. tflite::support::StatusOr<ClassificationResult> Classify(
  3726. - const FrameBuffer& frame_buffer, const BoundingBox& roi);
  3727. + const FrameBuffer& frame_buffer,
  3728. + const BoundingBox& roi);
  3729. protected:
  3730. // The options used to build this ImageClassifier.
  3731. @@ -123,7 +124,8 @@ class ImageClassifier : public BaseVisionTaskApi<ClassificationResult> {
  3732. // results.
  3733. tflite::support::StatusOr<ClassificationResult> Postprocess(
  3734. const std::vector<const TfLiteTensor*>& output_tensors,
  3735. - const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
  3736. + const FrameBuffer& frame_buffer,
  3737. + const BoundingBox& roi) override;
  3738. // Performs sanity checks on the provided ImageClassifierOptions.
  3739. static absl::Status SanityCheckOptions(const ImageClassifierOptions& options);
  3740. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc
  3741. index 0ce46fb9f9806..943a39b1f762e 100644
  3742. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc
  3743. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc
  3744. @@ -18,10 +18,10 @@ limitations under the License.
  3745. #include <algorithm>
  3746. #include "absl/container/node_hash_set.h" // from @com_google_absl
  3747. -#include "absl/memory/memory.h" // from @com_google_absl
  3748. -#include "absl/status/status.h" // from @com_google_absl
  3749. -#include "absl/strings/str_format.h" // from @com_google_absl
  3750. -#include "absl/strings/string_view.h" // from @com_google_absl
  3751. +#include "absl/memory/memory.h" // from @com_google_absl
  3752. +#include "absl/status/status.h" // from @com_google_absl
  3753. +#include "absl/strings/str_format.h" // from @com_google_absl
  3754. +#include "absl/strings/string_view.h" // from @com_google_absl
  3755. #include "tensorflow/lite/c/common.h"
  3756. #include "tensorflow_lite_support/cc/common.h"
  3757. #include "tensorflow_lite_support/cc/port/status_macros.h"
  3758. @@ -51,7 +51,8 @@ CreatePostprocessor(core::TfLiteEngine* engine,
  3759. /* static */
  3760. tflite::support::StatusOr<double> ImageEmbedder::CosineSimilarity(
  3761. - const FeatureVector& u, const FeatureVector& v) {
  3762. + const FeatureVector& u,
  3763. + const FeatureVector& v) {
  3764. return processor::EmbeddingPostprocessor::CosineSimilarity(u, v);
  3765. }
  3766. @@ -118,13 +119,15 @@ tflite::support::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
  3767. }
  3768. tflite::support::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
  3769. - const FrameBuffer& frame_buffer, const BoundingBox& roi) {
  3770. + const FrameBuffer& frame_buffer,
  3771. + const BoundingBox& roi) {
  3772. return InferWithFallback(frame_buffer, roi);
  3773. }
  3774. tflite::support::StatusOr<EmbeddingResult> ImageEmbedder::Postprocess(
  3775. const std::vector<const TfLiteTensor*>& output_tensors,
  3776. - const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) {
  3777. + const FrameBuffer& /*frame_buffer*/,
  3778. + const BoundingBox& /*roi*/) {
  3779. EmbeddingResult result;
  3780. for (int i = 0; i < postprocessors_.size(); ++i) {
  3781. RETURN_IF_ERROR(
  3782. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h
  3783. index bc321c83d3774..93e2455eebd19 100644
  3784. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h
  3785. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h
  3786. @@ -90,7 +90,8 @@ class ImageEmbedder
  3787. // region of interest. Note that the region of interest is not clamped, so
  3788. // this method will fail if the region is out of bounds of the input image.
  3789. tflite::support::StatusOr<EmbeddingResult> Embed(
  3790. - const FrameBuffer& frame_buffer, const BoundingBox& roi);
  3791. + const FrameBuffer& frame_buffer,
  3792. + const BoundingBox& roi);
  3793. // Returns the Embedding output by the output_index'th layer. In (the most
  3794. // common) case where a single embedding is produced, you can just call
  3795. @@ -113,7 +114,8 @@ class ImageEmbedder
  3796. //
  3797. // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
  3798. static tflite::support::StatusOr<double> CosineSimilarity(
  3799. - const FeatureVector& u, const FeatureVector& v);
  3800. + const FeatureVector& u,
  3801. + const FeatureVector& v);
  3802. protected:
  3803. // The options used to build this ImageEmbedder.
  3804. @@ -122,7 +124,8 @@ class ImageEmbedder
  3805. // Post-processing to transform the raw model outputs into embedding results.
  3806. tflite::support::StatusOr<EmbeddingResult> Postprocess(
  3807. const std::vector<const TfLiteTensor*>& output_tensors,
  3808. - const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
  3809. + const FrameBuffer& frame_buffer,
  3810. + const BoundingBox& roi) override;
  3811. // Performs pre-initialization actions.
  3812. virtual absl::Status PreInit();
  3813. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_searcher.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_searcher.cc
  3814. index fb8bdf4f36446..4916290cb1473 100644
  3815. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_searcher.cc
  3816. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_searcher.cc
  3817. @@ -19,8 +19,8 @@ limitations under the License.
  3818. #include <utility>
  3819. #include <vector>
  3820. -#include "absl/memory/memory.h" // from @com_google_absl
  3821. -#include "absl/status/status.h" // from @com_google_absl
  3822. +#include "absl/memory/memory.h" // from @com_google_absl
  3823. +#include "absl/status/status.h" // from @com_google_absl
  3824. #include "absl/strings/string_view.h" // from @com_google_absl
  3825. #include "tensorflow/lite/c/common.h"
  3826. #include "tensorflow/lite/core/api/op_resolver.h"
  3827. @@ -110,7 +110,8 @@ StatusOr<absl::string_view> ImageSearcher::GetUserInfo() {
  3828. StatusOr<SearchResult> ImageSearcher::Postprocess(
  3829. const std::vector<const TfLiteTensor*>& /*output_tensors*/,
  3830. - const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) {
  3831. + const FrameBuffer& /*frame_buffer*/,
  3832. + const BoundingBox& /*roi*/) {
  3833. return postprocessor_->Postprocess();
  3834. }
  3835. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_searcher.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_searcher.h
  3836. index 4a510a615ab5b..6b43f8d7736d9 100644
  3837. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_searcher.h
  3838. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_searcher.h
  3839. @@ -19,7 +19,7 @@ limitations under the License.
  3840. #include <memory>
  3841. #include <vector>
  3842. -#include "absl/status/status.h" // from @com_google_absl
  3843. +#include "absl/status/status.h" // from @com_google_absl
  3844. #include "absl/strings/string_view.h" // from @com_google_absl
  3845. #include "tensorflow/lite/core/api/op_resolver.h"
  3846. #include "tensorflow/lite/core/shims/cc/kernels/register.h"
  3847. @@ -93,7 +93,8 @@ class ImageSearcher
  3848. // region of interest. Note that the region of interest is not clamped, so
  3849. // this method will fail if the region is out of bounds of the input image.
  3850. tflite::support::StatusOr<tflite::task::processor::SearchResult> Search(
  3851. - const FrameBuffer& frame_buffer, const BoundingBox& roi);
  3852. + const FrameBuffer& frame_buffer,
  3853. + const BoundingBox& roi);
  3854. // Provides access to the opaque user info stored in the index file (if any),
  3855. // in raw binary form. Returns an empty string if the index doesn't contain
  3856. @@ -108,7 +109,8 @@ class ImageSearcher
  3857. // perform the nearest-neighbor search in the index.
  3858. tflite::support::StatusOr<tflite::task::processor::SearchResult> Postprocess(
  3859. const std::vector<const TfLiteTensor*>& output_tensors,
  3860. - const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
  3861. + const FrameBuffer& frame_buffer,
  3862. + const BoundingBox& roi) override;
  3863. // Initializes the ImageSearcher.
  3864. absl::Status Init(std::unique_ptr<ImageSearcherOptions> options);
  3865. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc
  3866. index c9dad866f1a68..1cf9a54b91e0f 100644
  3867. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc
  3868. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc
  3869. @@ -17,10 +17,10 @@ limitations under the License.
  3870. #include <algorithm>
  3871. -#include "absl/memory/memory.h" // from @com_google_absl
  3872. -#include "absl/strings/str_format.h" // from @com_google_absl
  3873. +#include "absl/memory/memory.h" // from @com_google_absl
  3874. +#include "absl/strings/str_format.h" // from @com_google_absl
  3875. #include "absl/strings/string_view.h" // from @com_google_absl
  3876. -#include "flatbuffers/flatbuffers.h" // from @flatbuffers
  3877. +#include "flatbuffers/flatbuffers.h" // from @flatbuffers
  3878. #include "tensorflow/lite/c/common.h"
  3879. #include "tensorflow_lite_support/cc/common.h"
  3880. #include "tensorflow_lite_support/cc/port/integral_types.h"
  3881. @@ -110,7 +110,8 @@ constexpr uint8 kColorMap[768] = {
  3882. StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny(
  3883. const ModelMetadataExtractor& metadata_extractor,
  3884. - const TensorMetadata& tensor_metadata, absl::string_view locale) {
  3885. + const TensorMetadata& tensor_metadata,
  3886. + absl::string_view locale) {
  3887. const std::string labels_filename =
  3888. ModelMetadataExtractor::FindFirstAssociatedFileName(
  3889. tensor_metadata, tflite::AssociatedFileType_TENSOR_AXIS_LABELS);
  3890. @@ -332,7 +333,8 @@ StatusOr<SegmentationResult> ImageSegmenter::Segment(
  3891. StatusOr<SegmentationResult> ImageSegmenter::Postprocess(
  3892. const std::vector<const TfLiteTensor*>& output_tensors,
  3893. - const FrameBuffer& frame_buffer, const BoundingBox& /*roi*/) {
  3894. + const FrameBuffer& frame_buffer,
  3895. + const BoundingBox& /*roi*/) {
  3896. if (output_tensors.size() != 1) {
  3897. return CreateStatusWithPayload(
  3898. StatusCode::kInternal,
  3899. @@ -432,7 +434,10 @@ StatusOr<SegmentationResult> ImageSegmenter::Postprocess(
  3900. }
  3901. StatusOr<float> ImageSegmenter::GetOutputConfidence(
  3902. - const TfLiteTensor& output_tensor, int x, int y, int depth) {
  3903. + const TfLiteTensor& output_tensor,
  3904. + int x,
  3905. + int y,
  3906. + int depth) {
  3907. int index = output_width_ * output_depth_ * y + output_depth_ * x + depth;
  3908. if (has_uint8_outputs_) {
  3909. ASSIGN_OR_RETURN(const uint8* data,
  3910. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h
  3911. index 3f51f4962738e..e255110d9dc66 100644
  3912. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h
  3913. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h
  3914. @@ -119,7 +119,8 @@ class ImageSegmenter : public BaseVisionTaskApi<SegmentationResult> {
  3915. // results.
  3916. tflite::support::StatusOr<SegmentationResult> Postprocess(
  3917. const std::vector<const TfLiteTensor*>& output_tensors,
  3918. - const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
  3919. + const FrameBuffer& frame_buffer,
  3920. + const BoundingBox& roi) override;
  3921. // Performs sanity checks on the provided ImageSegmenterOptions.
  3922. static absl::Status SanityCheckOptions(const ImageSegmenterOptions& options);
  3923. @@ -148,7 +149,10 @@ class ImageSegmenter : public BaseVisionTaskApi<SegmentationResult> {
  3924. // Returns the output confidence at coordinates {x, y, depth}, dequantizing
  3925. // on-the-fly if needed (i.e. if `has_uint8_outputs_` is true).
  3926. tflite::support::StatusOr<float> GetOutputConfidence(
  3927. - const TfLiteTensor& output_tensor, int x, int y, int depth);
  3928. + const TfLiteTensor& output_tensor,
  3929. + int x,
  3930. + int y,
  3931. + int depth);
  3932. // Prebuilt list of ColoredLabel attached to each Segmentation result. The
  3933. // i-th item in this list corresponds to the i-th label map item.
  3934. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc
  3935. index 0a4d5f7553ee9..00775015515ac 100644
  3936. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc
  3937. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc
  3938. @@ -20,8 +20,8 @@ limitations under the License.
  3939. #include <vector>
  3940. #include <glog/logging.h>
  3941. -#include "absl/memory/memory.h" // from @com_google_absl
  3942. -#include "absl/status/status.h" // from @com_google_absl
  3943. +#include "absl/memory/memory.h" // from @com_google_absl
  3944. +#include "absl/status/status.h" // from @com_google_absl
  3945. #include "absl/strings/str_format.h" // from @com_google_absl
  3946. #include "absl/strings/string_view.h" // from @com_google_absl
  3947. #include "tensorflow/lite/c/common.h"
  3948. @@ -141,7 +141,8 @@ StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties(
  3949. StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny(
  3950. const ModelMetadataExtractor& metadata_extractor,
  3951. - const TensorMetadata& tensor_metadata, absl::string_view locale) {
  3952. + const TensorMetadata& tensor_metadata,
  3953. + absl::string_view locale) {
  3954. const std::string labels_filename =
  3955. ModelMetadataExtractor::FindFirstAssociatedFileName(
  3956. tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS);
  3957. @@ -370,7 +371,9 @@ absl::Status ObjectDetector::PreInit() {
  3958. return absl::OkStatus();
  3959. }
  3960. -absl::Status ObjectDetector::PostInit() { return InitScoreCalibrations(); }
  3961. +absl::Status ObjectDetector::PostInit() {
  3962. + return InitScoreCalibrations();
  3963. +}
  3964. StatusOr<SigmoidCalibrationParameters> BuildCalibrationParametersIfAny(
  3965. const tflite::metadata::ModelMetadataExtractor& metadata_extractor,
  3966. @@ -599,7 +602,8 @@ StatusOr<DetectionResult> ObjectDetector::Detect(
  3967. StatusOr<DetectionResult> ObjectDetector::Postprocess(
  3968. const std::vector<const TfLiteTensor*>& output_tensors,
  3969. - const FrameBuffer& frame_buffer, const BoundingBox& /*roi*/) {
  3970. + const FrameBuffer& frame_buffer,
  3971. + const BoundingBox& /*roi*/) {
  3972. // Most of the checks here should never happen, as outputs have been validated
  3973. // at construction time. Checking nonetheless and returning internal errors if
  3974. // something bad happens.
  3975. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h
  3976. index eaa6b5371ba52..c37fa8771081e 100644
  3977. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h
  3978. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h
  3979. @@ -19,7 +19,7 @@ limitations under the License.
  3980. #include <memory>
  3981. #include "absl/container/flat_hash_set.h" // from @com_google_absl
  3982. -#include "absl/status/status.h" // from @com_google_absl
  3983. +#include "absl/status/status.h" // from @com_google_absl
  3984. #include "tensorflow/lite/core/api/op_resolver.h"
  3985. #include "tensorflow/lite/core/shims/cc/kernels/register.h"
  3986. #include "tensorflow_lite_support/cc/port/statusor.h"
  3987. @@ -123,7 +123,8 @@ class ObjectDetector : public BaseVisionTaskApi<DetectionResult> {
  3988. // Post-processing to transform the raw model outputs into detection results.
  3989. tflite::support::StatusOr<DetectionResult> Postprocess(
  3990. const std::vector<const TfLiteTensor*>& output_tensors,
  3991. - const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
  3992. + const FrameBuffer& frame_buffer,
  3993. + const BoundingBox& roi) override;
  3994. // Performs sanity checks on the provided ObjectDetectorOptions.
  3995. static absl::Status SanityCheckOptions(const ObjectDetectorOptions& options);
  3996. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_searcher_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_searcher_options.proto
  3997. index 7501bb24d659d..5b5aaf1fa035c 100644
  3998. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_searcher_options.proto
  3999. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_searcher_options.proto
  4000. @@ -21,7 +21,6 @@ import "tensorflow_lite_support/cc/task/core/proto/base_options.proto";
  4001. import "tensorflow_lite_support/cc/task/processor/proto/embedding_options.proto";
  4002. import "tensorflow_lite_support/cc/task/processor/proto/search_options.proto";
  4003. -
  4004. // Options for setting up an ImageSearcher.
  4005. // Next Id: 4.
  4006. message ImageSearcherOptions {
  4007. @@ -37,5 +36,4 @@ message ImageSearcherOptions {
  4008. // Options specifying the index to search into and controlling the search
  4009. // behavior.
  4010. optional tflite.task.processor.SearchOptions search_options = 3;
  4011. -
  4012. }
  4013. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc
  4014. index 1854cf546d599..9a5b96160c033 100644
  4015. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc
  4016. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc
  4017. @@ -18,7 +18,7 @@ limitations under the License.
  4018. #include <string>
  4019. #include <vector>
  4020. -#include "absl/strings/str_cat.h" // from @com_google_absl
  4021. +#include "absl/strings/str_cat.h" // from @com_google_absl
  4022. #include "absl/strings/str_format.h" // from @com_google_absl
  4023. #include "tensorflow_lite_support/cc/port/status_macros.h"
  4024. @@ -36,8 +36,10 @@ constexpr int kGrayChannel = 1;
  4025. // Creates a FrameBuffer from one plane raw NV21/NV12 buffer and passing
  4026. // arguments.
  4027. StatusOr<std::unique_ptr<FrameBuffer>> CreateFromOnePlaneNVRawBuffer(
  4028. - const uint8* input, FrameBuffer::Dimension dimension,
  4029. - FrameBuffer::Format format, FrameBuffer::Orientation orientation,
  4030. + const uint8* input,
  4031. + FrameBuffer::Dimension dimension,
  4032. + FrameBuffer::Format format,
  4033. + FrameBuffer::Orientation orientation,
  4034. const absl::Time timestamp) {
  4035. FrameBuffer::Plane input_plane = {/*buffer=*/input,
  4036. /*stride=*/{dimension.width, kGrayChannel}};
  4037. @@ -129,7 +131,8 @@ StatusOr<const uint8*> GetUvRawBuffer(const FrameBuffer& buffer) {
  4038. }
  4039. StatusOr<FrameBuffer::Dimension> GetUvPlaneDimension(
  4040. - FrameBuffer::Dimension dimension, FrameBuffer::Format format) {
  4041. + FrameBuffer::Dimension dimension,
  4042. + FrameBuffer::Format format) {
  4043. if (dimension.width <= 0 || dimension.height <= 0) {
  4044. return absl::InvalidArgumentError(
  4045. absl::StrFormat("Invalid input dimension: {%d, %d}.", dimension.width,
  4046. @@ -176,7 +179,8 @@ absl::Status ValidateBufferFormat(const FrameBuffer& buffer) {
  4047. case FrameBuffer::Format::kGRAY:
  4048. case FrameBuffer::Format::kRGB:
  4049. case FrameBuffer::Format::kRGBA:
  4050. - if (buffer.plane_count() == 1) return absl::OkStatus();
  4051. + if (buffer.plane_count() == 1)
  4052. + return absl::OkStatus();
  4053. return absl::InvalidArgumentError(
  4054. "Plane count must be 1 for grayscale and RGB[a] buffers.");
  4055. case FrameBuffer::Format::kNV21:
  4056. @@ -252,8 +256,11 @@ absl::Status ValidateRotateBufferInputs(const FrameBuffer& buffer,
  4057. }
  4058. absl::Status ValidateCropBufferInputs(const FrameBuffer& buffer,
  4059. - const FrameBuffer& output_buffer, int x0,
  4060. - int y0, int x1, int y1) {
  4061. + const FrameBuffer& output_buffer,
  4062. + int x0,
  4063. + int y0,
  4064. + int x1,
  4065. + int y1) {
  4066. if (!AreBufferFormatsCompatible(buffer, output_buffer)) {
  4067. return absl::InvalidArgumentError(
  4068. "Input and output buffer formats must match.");
  4069. @@ -309,8 +316,10 @@ absl::Status ValidateConvertFormats(FrameBuffer::Format from_format,
  4070. // Creates a FrameBuffer from raw RGBA buffer and passing arguments.
  4071. std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer(
  4072. - const uint8* input, FrameBuffer::Dimension dimension,
  4073. - FrameBuffer::Orientation orientation, const absl::Time timestamp,
  4074. + const uint8* input,
  4075. + FrameBuffer::Dimension dimension,
  4076. + FrameBuffer::Orientation orientation,
  4077. + const absl::Time timestamp,
  4078. FrameBuffer::Stride stride) {
  4079. if (stride == kDefaultStride) {
  4080. stride.row_stride_bytes = dimension.width * kRgbaChannels;
  4081. @@ -325,8 +334,10 @@ std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer(
  4082. // Creates a FrameBuffer from raw RGB buffer and passing arguments.
  4083. std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer(
  4084. - const uint8* input, FrameBuffer::Dimension dimension,
  4085. - FrameBuffer::Orientation orientation, const absl::Time timestamp,
  4086. + const uint8* input,
  4087. + FrameBuffer::Dimension dimension,
  4088. + FrameBuffer::Orientation orientation,
  4089. + const absl::Time timestamp,
  4090. FrameBuffer::Stride stride) {
  4091. if (stride == kDefaultStride) {
  4092. stride.row_stride_bytes = dimension.width * kRgbChannels;
  4093. @@ -340,8 +351,10 @@ std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer(
  4094. // Creates a FrameBuffer from raw grayscale buffer and passing arguments.
  4095. std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer(
  4096. - const uint8* input, FrameBuffer::Dimension dimension,
  4097. - FrameBuffer::Orientation orientation, const absl::Time timestamp,
  4098. + const uint8* input,
  4099. + FrameBuffer::Dimension dimension,
  4100. + FrameBuffer::Orientation orientation,
  4101. + const absl::Time timestamp,
  4102. FrameBuffer::Stride stride) {
  4103. if (stride == kDefaultStride) {
  4104. stride.row_stride_bytes = dimension.width * kGrayChannel;
  4105. @@ -356,10 +369,16 @@ std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer(
  4106. // Creates a FrameBuffer from raw YUV buffer and passing arguments.
  4107. StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer(
  4108. - const uint8* y_plane, const uint8* u_plane, const uint8* v_plane,
  4109. - FrameBuffer::Format format, FrameBuffer::Dimension dimension,
  4110. - int row_stride_y, int row_stride_uv, int pixel_stride_uv,
  4111. - FrameBuffer::Orientation orientation, const absl::Time timestamp) {
  4112. + const uint8* y_plane,
  4113. + const uint8* u_plane,
  4114. + const uint8* v_plane,
  4115. + FrameBuffer::Format format,
  4116. + FrameBuffer::Dimension dimension,
  4117. + int row_stride_y,
  4118. + int row_stride_uv,
  4119. + int pixel_stride_uv,
  4120. + FrameBuffer::Orientation orientation,
  4121. + const absl::Time timestamp) {
  4122. const int pixel_stride_y = 1;
  4123. std::vector<FrameBuffer::Plane> planes;
  4124. if (format == FrameBuffer::Format::kNV21 ||
  4125. @@ -380,9 +399,11 @@ StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer(
  4126. }
  4127. StatusOr<std::unique_ptr<FrameBuffer>> CreateFromRawBuffer(
  4128. - const uint8* buffer, FrameBuffer::Dimension dimension,
  4129. + const uint8* buffer,
  4130. + FrameBuffer::Dimension dimension,
  4131. const FrameBuffer::Format target_format,
  4132. - FrameBuffer::Orientation orientation, absl::Time timestamp) {
  4133. + FrameBuffer::Orientation orientation,
  4134. + absl::Time timestamp) {
  4135. switch (target_format) {
  4136. case FrameBuffer::Format::kNV12:
  4137. return CreateFromOnePlaneNVRawBuffer(buffer, dimension, target_format,
  4138. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h
  4139. index 470e76b9037a1..7ebf69fadc3de 100644
  4140. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h
  4141. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h
  4142. @@ -18,8 +18,8 @@ limitations under the License.
  4143. #include <memory>
  4144. #include "absl/status/status.h" // from @com_google_absl
  4145. -#include "absl/time/clock.h" // from @com_google_absl
  4146. -#include "absl/time/time.h" // from @com_google_absl
  4147. +#include "absl/time/clock.h" // from @com_google_absl
  4148. +#include "absl/time/time.h" // from @com_google_absl
  4149. #include "tensorflow_lite_support/cc/port/integral_types.h"
  4150. #include "tensorflow_lite_support/cc/port/statusor.h"
  4151. #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
  4152. @@ -58,7 +58,8 @@ tflite::support::StatusOr<const uint8*> GetUvRawBuffer(
  4153. // supported formats. This method assums the UV plane share the same dimension,
  4154. // especially for the YV12 / YV21 formats.
  4155. tflite::support::StatusOr<FrameBuffer::Dimension> GetUvPlaneDimension(
  4156. - FrameBuffer::Dimension dimension, FrameBuffer::Format format);
  4157. + FrameBuffer::Dimension dimension,
  4158. + FrameBuffer::Format format);
  4159. // Returns crop dimension based on crop start and end points.
  4160. FrameBuffer::Dimension GetCropDimension(int x0, int x1, int y0, int y1);
  4161. @@ -92,8 +93,11 @@ absl::Status ValidateRotateBufferInputs(const FrameBuffer& buffer,
  4162. // (x0, y0) represents the top-left point of the buffer.
  4163. // (x1, y1) represents the bottom-right point of the buffer.
  4164. absl::Status ValidateCropBufferInputs(const FrameBuffer& buffer,
  4165. - const FrameBuffer& output_buffer, int x0,
  4166. - int y0, int x1, int y1);
  4167. + const FrameBuffer& output_buffer,
  4168. + int x0,
  4169. + int y0,
  4170. + int x1,
  4171. + int y1);
  4172. // Validates the given inputs for flipping `buffer` horizontally or vertically.
  4173. absl::Status ValidateFlipBufferInputs(const FrameBuffer& buffer,
  4174. @@ -110,36 +114,45 @@ absl::Status ValidateConvertFormats(FrameBuffer::Format from_format,
  4175. // Creates a FrameBuffer from raw RGBA buffer and passing arguments.
  4176. std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer(
  4177. - const uint8* input, FrameBuffer::Dimension dimension,
  4178. + const uint8* input,
  4179. + FrameBuffer::Dimension dimension,
  4180. FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
  4181. absl::Time timestamp = absl::Now(),
  4182. FrameBuffer::Stride stride = kDefaultStride);
  4183. // Creates a FrameBuffer from raw RGB buffer and passing arguments.
  4184. std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer(
  4185. - const uint8* input, FrameBuffer::Dimension dimension,
  4186. + const uint8* input,
  4187. + FrameBuffer::Dimension dimension,
  4188. FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
  4189. absl::Time timestamp = absl::Now(),
  4190. FrameBuffer::Stride stride = kDefaultStride);
  4191. // Creates a FrameBuffer from raw grayscale buffer and passing arguments.
  4192. std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer(
  4193. - const uint8* input, FrameBuffer::Dimension dimension,
  4194. + const uint8* input,
  4195. + FrameBuffer::Dimension dimension,
  4196. FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
  4197. absl::Time timestamp = absl::Now(),
  4198. FrameBuffer::Stride stride = kDefaultStride);
  4199. // Creates a FrameBuffer from raw YUV buffer and passing arguments.
  4200. tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer(
  4201. - const uint8* y_plane, const uint8* u_plane, const uint8* v_plane,
  4202. - FrameBuffer::Format format, FrameBuffer::Dimension dimension,
  4203. - int row_stride_y, int row_stride_uv, int pixel_stride_uv,
  4204. + const uint8* y_plane,
  4205. + const uint8* u_plane,
  4206. + const uint8* v_plane,
  4207. + FrameBuffer::Format format,
  4208. + FrameBuffer::Dimension dimension,
  4209. + int row_stride_y,
  4210. + int row_stride_uv,
  4211. + int pixel_stride_uv,
  4212. FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
  4213. absl::Time timestamp = absl::Now());
  4214. // Creates an instance of FrameBuffer from raw buffer and passing arguments.
  4215. tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> CreateFromRawBuffer(
  4216. - const uint8* buffer, FrameBuffer::Dimension dimension,
  4217. + const uint8* buffer,
  4218. + FrameBuffer::Dimension dimension,
  4219. FrameBuffer::Format target_format,
  4220. FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
  4221. absl::Time timestamp = absl::Now());
  4222. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc
  4223. index 4d767fc3e48b2..4728c30cb60dc 100644
  4224. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc
  4225. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc
  4226. @@ -22,8 +22,8 @@ limitations under the License.
  4227. #include <utility>
  4228. #include <vector>
  4229. -#include "absl/memory/memory.h" // from @com_google_absl
  4230. -#include "absl/status/status.h" // from @com_google_absl
  4231. +#include "absl/memory/memory.h" // from @com_google_absl
  4232. +#include "absl/status/status.h" // from @com_google_absl
  4233. #include "absl/strings/str_format.h" // from @com_google_absl
  4234. #include "tensorflow/lite/kernels/internal/compatibility.h"
  4235. #include "tensorflow/lite/kernels/op_macros.h"
  4236. @@ -91,7 +91,8 @@ static int GetOrientationIndex(FrameBuffer::Orientation orientation) {
  4237. // The new box origin is (x:box.origin_y, y:width - (box.origin_x + box.width).
  4238. // The new box dimension is (w: box.height, h: box.width).
  4239. //
  4240. -static BoundingBox RotateBoundingBox(const BoundingBox& box, int angle,
  4241. +static BoundingBox RotateBoundingBox(const BoundingBox& box,
  4242. + int angle,
  4243. FrameBuffer::Dimension frame_dimension) {
  4244. int rx = box.origin_x(), ry = box.origin_y(), rw = box.width(),
  4245. rh = box.height();
  4246. @@ -130,9 +131,12 @@ static BoundingBox RotateBoundingBox(const BoundingBox& box, int angle,
  4247. // in counterclockwise degree in one of the values [0, 90, 180, 270].
  4248. //
  4249. // See `RotateBoundingBox` above for more details.
  4250. -static void RotateCoordinates(int from_x, int from_y, int angle,
  4251. +static void RotateCoordinates(int from_x,
  4252. + int from_y,
  4253. + int angle,
  4254. const FrameBuffer::Dimension& frame_dimension,
  4255. - int* to_x, int* to_y) {
  4256. + int* to_x,
  4257. + int* to_y) {
  4258. switch (angle) {
  4259. case 0:
  4260. *to_x = from_x;
  4261. @@ -199,7 +203,10 @@ BoundingBox OrientBoundingBox(const BoundingBox& from_box,
  4262. }
  4263. BoundingBox OrientAndDenormalizeBoundingBox(
  4264. - float from_left, float from_top, float from_right, float from_bottom,
  4265. + float from_left,
  4266. + float from_top,
  4267. + float from_right,
  4268. + float from_bottom,
  4269. FrameBuffer::Orientation from_orientation,
  4270. FrameBuffer::Orientation to_orientation,
  4271. FrameBuffer::Dimension from_dimension) {
  4272. @@ -214,10 +221,12 @@ BoundingBox OrientAndDenormalizeBoundingBox(
  4273. return to_box;
  4274. }
  4275. -void OrientCoordinates(int from_x, int from_y,
  4276. +void OrientCoordinates(int from_x,
  4277. + int from_y,
  4278. FrameBuffer::Orientation from_orientation,
  4279. FrameBuffer::Orientation to_orientation,
  4280. - FrameBuffer::Dimension from_dimension, int* to_x,
  4281. + FrameBuffer::Dimension from_dimension,
  4282. + int* to_x,
  4283. int* to_y) {
  4284. *to_x = from_x;
  4285. *to_y = from_y;
  4286. @@ -298,15 +307,19 @@ bool RequireDimensionSwap(FrameBuffer::Orientation from_orientation,
  4287. return params.rotation_angle_deg == 90 || params.rotation_angle_deg == 270;
  4288. }
  4289. -absl::Status FrameBufferUtils::Crop(const FrameBuffer& buffer, int x0, int y0,
  4290. - int x1, int y1,
  4291. +absl::Status FrameBufferUtils::Crop(const FrameBuffer& buffer,
  4292. + int x0,
  4293. + int y0,
  4294. + int x1,
  4295. + int y1,
  4296. FrameBuffer* output_buffer) {
  4297. TFLITE_DCHECK(utils_ != nullptr);
  4298. return utils_->Crop(buffer, x0, y0, x1, y1, output_buffer);
  4299. }
  4300. FrameBuffer::Dimension FrameBufferUtils::GetSize(
  4301. - const FrameBuffer& buffer, const FrameBufferOperation& operation) {
  4302. + const FrameBuffer& buffer,
  4303. + const FrameBufferOperation& operation) {
  4304. FrameBuffer::Dimension dimension = buffer.dimension();
  4305. if (absl::holds_alternative<OrientOperation>(operation)) {
  4306. OrientParams params =
  4307. @@ -327,7 +340,8 @@ FrameBuffer::Dimension FrameBufferUtils::GetSize(
  4308. }
  4309. std::vector<FrameBuffer::Plane> FrameBufferUtils::GetPlanes(
  4310. - const uint8* buffer, FrameBuffer::Dimension dimension,
  4311. + const uint8* buffer,
  4312. + FrameBuffer::Dimension dimension,
  4313. FrameBuffer::Format format) {
  4314. std::vector<FrameBuffer::Plane> planes;
  4315. switch (format) {
  4316. @@ -378,7 +392,8 @@ std::vector<FrameBuffer::Plane> FrameBufferUtils::GetPlanes(
  4317. }
  4318. FrameBuffer::Orientation FrameBufferUtils::GetOrientation(
  4319. - const FrameBuffer& buffer, const FrameBufferOperation& operation) {
  4320. + const FrameBuffer& buffer,
  4321. + const FrameBufferOperation& operation) {
  4322. if (absl::holds_alternative<OrientOperation>(operation)) {
  4323. return absl::get<OrientOperation>(operation).to_orientation;
  4324. }
  4325. @@ -386,7 +401,8 @@ FrameBuffer::Orientation FrameBufferUtils::GetOrientation(
  4326. }
  4327. FrameBuffer::Format FrameBufferUtils::GetFormat(
  4328. - const FrameBuffer& buffer, const FrameBufferOperation& operation) {
  4329. + const FrameBuffer& buffer,
  4330. + const FrameBufferOperation& operation) {
  4331. if (absl::holds_alternative<ConvertOperation>(operation)) {
  4332. return absl::get<ConvertOperation>(operation).to_format;
  4333. }
  4334. @@ -578,8 +594,10 @@ absl::Status FrameBufferUtils::Execute(
  4335. }
  4336. absl::Status FrameBufferUtils::Preprocess(
  4337. - const FrameBuffer& buffer, absl::optional<BoundingBox> bounding_box,
  4338. - FrameBuffer* output_buffer, bool uniform_resizing) {
  4339. + const FrameBuffer& buffer,
  4340. + absl::optional<BoundingBox> bounding_box,
  4341. + FrameBuffer* output_buffer,
  4342. + bool uniform_resizing) {
  4343. std::vector<FrameBufferOperation> frame_buffer_operations;
  4344. // Handle cropping and resizing.
  4345. bool needs_dimension_swap =
  4346. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h
  4347. index 59e80e5765bb0..48549461159cb 100644
  4348. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h
  4349. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h
  4350. @@ -19,9 +19,9 @@ limitations under the License.
  4351. #include <memory>
  4352. #include <vector>
  4353. -#include "absl/status/status.h" // from @com_google_absl
  4354. +#include "absl/status/status.h" // from @com_google_absl
  4355. #include "absl/types/optional.h" // from @com_google_absl
  4356. -#include "absl/types/variant.h" // from @com_google_absl
  4357. +#include "absl/types/variant.h" // from @com_google_absl
  4358. #include "tensorflow_lite_support/cc/port/integral_types.h"
  4359. #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
  4360. #include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
  4361. @@ -45,7 +45,10 @@ BoundingBox OrientBoundingBox(const BoundingBox& from_box,
  4362. // Same as OrientBoundingBox but from normalized coordinates.
  4363. BoundingBox OrientAndDenormalizeBoundingBox(
  4364. - float from_left, float from_top, float from_right, float from_bottom,
  4365. + float from_left,
  4366. + float from_top,
  4367. + float from_right,
  4368. + float from_bottom,
  4369. FrameBuffer::Orientation from_orientation,
  4370. FrameBuffer::Orientation to_orientation,
  4371. FrameBuffer::Dimension from_dimension);
  4372. @@ -53,10 +56,12 @@ BoundingBox OrientAndDenormalizeBoundingBox(
  4373. // Rotates `(from_x, from_y)` coordinates from an image of dimension
  4374. // `from_dimension` and orientation `from_orientation` into `(to_x, to_y)`
  4375. // coordinates with orientation `to_orientation`.
  4376. -void OrientCoordinates(int from_x, int from_y,
  4377. +void OrientCoordinates(int from_x,
  4378. + int from_y,
  4379. FrameBuffer::Orientation from_orientation,
  4380. FrameBuffer::Orientation to_orientation,
  4381. - FrameBuffer::Dimension from_dimension, int* to_x,
  4382. + FrameBuffer::Dimension from_dimension,
  4383. + int* to_x,
  4384. int* to_y);
  4385. // Returns whether the conversion from from_orientation to to_orientation
  4386. @@ -92,7 +97,8 @@ OrientParams GetOrientParams(FrameBuffer::Orientation from_orientation,
  4387. // To perform just cropping, the `crop_width` and `crop_height` should be the
  4388. // same as `resize_width` `and resize_height`.
  4389. struct CropResizeOperation {
  4390. - CropResizeOperation(int crop_origin_x, int crop_origin_y,
  4391. + CropResizeOperation(int crop_origin_x,
  4392. + int crop_origin_y,
  4393. FrameBuffer::Dimension crop_dimension,
  4394. FrameBuffer::Dimension resize_dimension)
  4395. : crop_origin_x(crop_origin_x),
  4396. @@ -124,7 +130,8 @@ struct CropResizeOperation {
  4397. // The resized region is aligned to the upper left pixel of the output buffer.
  4398. // The unfilled area of the output buffer remains untouched.
  4399. struct UniformCropResizeOperation {
  4400. - UniformCropResizeOperation(int crop_origin_x, int crop_origin_y,
  4401. + UniformCropResizeOperation(int crop_origin_x,
  4402. + int crop_origin_y,
  4403. FrameBuffer::Dimension crop_dimension,
  4404. FrameBuffer::Dimension output_dimension)
  4405. : crop_origin_x(crop_origin_x),
  4406. @@ -154,9 +161,10 @@ struct OrientOperation {
  4407. // A variant of the supported operations on FrameBuffers. Alias for user
  4408. // convenience.
  4409. -using FrameBufferOperation =
  4410. - absl::variant<CropResizeOperation, ConvertOperation, OrientOperation,
  4411. - UniformCropResizeOperation>;
  4412. +using FrameBufferOperation = absl::variant<CropResizeOperation,
  4413. + ConvertOperation,
  4414. + OrientOperation,
  4415. + UniformCropResizeOperation>;
  4416. // Image processing utility. This utility provides both basic image buffer
  4417. // manipulations (e.g. rotation, format conversion, resizing, etc) as well as
  4418. @@ -212,7 +220,11 @@ class FrameBufferUtils {
  4419. // should be big enough to store the operation result. If the `output_buffer`
  4420. // size dimension does not match with crop dimension, then a resize is
  4421. // automatically performed.
  4422. - absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
  4423. + absl::Status Crop(const FrameBuffer& buffer,
  4424. + int x0,
  4425. + int y0,
  4426. + int x1,
  4427. + int y1,
  4428. FrameBuffer* output_buffer);
  4429. // Performs resizing operation.
  4430. @@ -229,7 +241,8 @@ class FrameBufferUtils {
  4431. //
  4432. // The output_buffer should have metadata populated and its backing buffer
  4433. // should be big enough to store the operation result.
  4434. - absl::Status Rotate(const FrameBuffer& buffer, RotationDegree rotation,
  4435. + absl::Status Rotate(const FrameBuffer& buffer,
  4436. + RotationDegree rotation,
  4437. FrameBuffer* output_buffer);
  4438. // Performs horizontal flip operation.
  4439. @@ -305,7 +318,8 @@ class FrameBufferUtils {
  4440. // Returns the new FrameBuffer orientation after command is processed.
  4441. FrameBuffer::Orientation GetOrientation(
  4442. - const FrameBuffer& buffer, const FrameBufferOperation& operation);
  4443. + const FrameBuffer& buffer,
  4444. + const FrameBufferOperation& operation);
  4445. // Returns the new FrameBuffer format after command is processed.
  4446. FrameBuffer::Format GetFormat(const FrameBuffer& buffer,
  4447. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h
  4448. index ec0c3119ea4e8..59da2206bb06f 100644
  4449. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h
  4450. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h
  4451. @@ -37,8 +37,12 @@ class FrameBufferUtilsInterface {
  4452. //
  4453. // The `output_buffer` should have metadata populated and its backing buffer
  4454. // should be big enough to store the operation result.
  4455. - virtual absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1,
  4456. - int y1, FrameBuffer* output_buffer) = 0;
  4457. + virtual absl::Status Crop(const FrameBuffer& buffer,
  4458. + int x0,
  4459. + int y0,
  4460. + int x1,
  4461. + int y1,
  4462. + FrameBuffer* output_buffer) = 0;
  4463. // Resizes `buffer` to the size of the given `output_buffer`.
  4464. //
  4465. @@ -57,7 +61,8 @@ class FrameBufferUtilsInterface {
  4466. //
  4467. // The `output_buffer` should have metadata populated and its backing buffer
  4468. // should be big enough to store the operation result.
  4469. - virtual absl::Status Rotate(const FrameBuffer& buffer, int angle_deg,
  4470. + virtual absl::Status Rotate(const FrameBuffer& buffer,
  4471. + int angle_deg,
  4472. FrameBuffer* output_buffer) = 0;
  4473. // Flips `buffer` horizontally.
  4474. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.cc
  4475. index 3f8bc7b43f4f1..d5b277ad33b89 100644
  4476. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.cc
  4477. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.cc
  4478. @@ -23,11 +23,11 @@ limitations under the License.
  4479. #define STB_IMAGE_IMPLEMENTATION
  4480. #define STB_IMAGE_WRITE_IMPLEMENTATION
  4481. -#include "absl/status/status.h" // from @com_google_absl
  4482. -#include "absl/strings/match.h" // from @com_google_absl
  4483. +#include "absl/status/status.h" // from @com_google_absl
  4484. +#include "absl/strings/match.h" // from @com_google_absl
  4485. #include "absl/strings/str_format.h" // from @com_google_absl
  4486. -#include "stb_image.h" // from @stblib
  4487. -#include "stb_image_write.h" // from @stblib
  4488. +#include "stb_image.h" // from @stblib
  4489. +#include "stb_image_write.h" // from @stblib
  4490. #include "tensorflow_lite_support/cc/port/status_macros.h"
  4491. #include "tensorflow_lite_support/cc/port/statusor.h"
  4492. #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
  4493. @@ -88,7 +88,9 @@ absl::Status EncodeImageToPngFile(const ImageData& image_data,
  4494. return absl::OkStatus();
  4495. }
  4496. -void ImageDataFree(ImageData* image) { stbi_image_free(image->pixel_data); }
  4497. +void ImageDataFree(ImageData* image) {
  4498. + stbi_image_free(image->pixel_data);
  4499. +}
  4500. tflite::support::StatusOr<std::unique_ptr<FrameBuffer>>
  4501. CreateFrameBufferFromImageData(const ImageData& image) {
  4502. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.h
  4503. index 6ba5c2d6490ab..7de32ee9c0f53 100644
  4504. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.h
  4505. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.h
  4506. @@ -15,7 +15,7 @@ limitations under the License.
  4507. #ifndef TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_
  4508. #define TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_
  4509. -#include "absl/status/status.h" // from @com_google_absl
  4510. +#include "absl/status/status.h" // from @com_google_absl
  4511. #include "absl/strings/string_view.h" // from @com_google_absl
  4512. #include "tensorflow_lite_support/cc/port/integral_types.h"
  4513. #include "tensorflow_lite_support/cc/port/statusor.h"
  4514. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
  4515. index e0dd8a99c64c0..a0ee2dab96b6a 100644
  4516. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
  4517. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
  4518. @@ -20,11 +20,11 @@ limitations under the License.
  4519. #include <memory>
  4520. #include <string>
  4521. -#include "absl/status/status.h" // from @com_google_absl
  4522. -#include "absl/strings/str_cat.h" // from @com_google_absl
  4523. +#include "absl/status/status.h" // from @com_google_absl
  4524. +#include "absl/strings/str_cat.h" // from @com_google_absl
  4525. #include "absl/strings/str_format.h" // from @com_google_absl
  4526. -#include "libyuv.h" // from @libyuv
  4527. -#include "libyuv/convert_argb.h" // from @libyuv
  4528. +#include "libyuv.h" // from @libyuv
  4529. +#include "libyuv/convert_argb.h" // from @libyuv
  4530. #include "tensorflow_lite_support/cc/common.h"
  4531. #include "tensorflow_lite_support/cc/port/integral_types.h"
  4532. #include "tensorflow_lite_support/cc/port/status_macros.h"
  4533. @@ -384,7 +384,8 @@ absl::Status ResizeNv(const FrameBuffer& buffer, FrameBuffer* output_buffer) {
  4534. // Converts `buffer` to libyuv ARGB format and stores the conversion result
  4535. // in `dest_argb`.
  4536. -absl::Status ConvertRgbToArgb(const FrameBuffer& buffer, uint8* dest_argb,
  4537. +absl::Status ConvertRgbToArgb(const FrameBuffer& buffer,
  4538. + uint8* dest_argb,
  4539. int dest_stride_argb) {
  4540. RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
  4541. if (buffer.format() != FrameBuffer::Format::kRGB) {
  4542. @@ -421,7 +422,8 @@ absl::Status ConvertRgbToArgb(const FrameBuffer& buffer, uint8* dest_argb,
  4543. // Converts `src_argb` in libyuv ARGB format to FrameBuffer::kRGB format and
  4544. // stores the conversion result in `output_buffer`.
  4545. -absl::Status ConvertArgbToRgb(uint8* src_argb, int src_stride_argb,
  4546. +absl::Status ConvertArgbToRgb(uint8* src_argb,
  4547. + int src_stride_argb,
  4548. FrameBuffer* output_buffer) {
  4549. RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
  4550. if (output_buffer->format() != FrameBuffer::Format::kRGB) {
  4551. @@ -457,7 +459,8 @@ absl::Status ConvertArgbToRgb(uint8* src_argb, int src_stride_argb,
  4552. // Converts `buffer` in FrameBuffer::kRGBA format to libyuv ARGB (BGRA in
  4553. // memory) format and stores the conversion result in `dest_argb`.
  4554. -absl::Status ConvertRgbaToArgb(const FrameBuffer& buffer, uint8* dest_argb,
  4555. +absl::Status ConvertRgbaToArgb(const FrameBuffer& buffer,
  4556. + uint8* dest_argb,
  4557. int dest_stride_argb) {
  4558. RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
  4559. if (buffer.format() != FrameBuffer::Format::kRGBA) {
  4560. @@ -689,7 +692,8 @@ libyuv::RotationMode GetLibyuvRotationMode(int angle_deg) {
  4561. }
  4562. }
  4563. -absl::Status RotateRgba(const FrameBuffer& buffer, int angle_deg,
  4564. +absl::Status RotateRgba(const FrameBuffer& buffer,
  4565. + int angle_deg,
  4566. FrameBuffer* output_buffer) {
  4567. if (buffer.plane_count() > 1) {
  4568. return CreateStatusWithPayload(
  4569. @@ -713,7 +717,8 @@ absl::Status RotateRgba(const FrameBuffer& buffer, int angle_deg,
  4570. return absl::OkStatus();
  4571. }
  4572. -absl::Status RotateRgb(const FrameBuffer& buffer, int angle_deg,
  4573. +absl::Status RotateRgb(const FrameBuffer& buffer,
  4574. + int angle_deg,
  4575. FrameBuffer* output_buffer) {
  4576. // libyuv does not support rotate kRGB (RGB24) foramat. In this method, the
  4577. // implementation converts kRGB format to ARGB and use ARGB buffer for
  4578. @@ -746,7 +751,8 @@ absl::Status RotateRgb(const FrameBuffer& buffer, int angle_deg,
  4579. output_buffer);
  4580. }
  4581. -absl::Status RotateGray(const FrameBuffer& buffer, int angle_deg,
  4582. +absl::Status RotateGray(const FrameBuffer& buffer,
  4583. + int angle_deg,
  4584. FrameBuffer* output_buffer) {
  4585. if (buffer.plane_count() > 1) {
  4586. return CreateStatusWithPayload(
  4587. @@ -769,7 +775,8 @@ absl::Status RotateGray(const FrameBuffer& buffer, int angle_deg,
  4588. }
  4589. // Rotates YV12/YV21 frame buffer.
  4590. -absl::Status RotateYv(const FrameBuffer& buffer, int angle_deg,
  4591. +absl::Status RotateYv(const FrameBuffer& buffer,
  4592. + int angle_deg,
  4593. FrameBuffer* output_buffer) {
  4594. ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
  4595. FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
  4596. @@ -794,7 +801,8 @@ absl::Status RotateYv(const FrameBuffer& buffer, int angle_deg,
  4597. // Rotates NV12/NV21 frame buffer.
  4598. // TODO(b/152097364): Refactor NV12/NV21 rotation after libyuv explicitly
  4599. // support that.
  4600. -absl::Status RotateNv(const FrameBuffer& buffer, int angle_deg,
  4601. +absl::Status RotateNv(const FrameBuffer& buffer,
  4602. + int angle_deg,
  4603. FrameBuffer* output_buffer) {
  4604. if (buffer.format() != FrameBuffer::Format::kNV12 &&
  4605. buffer.format() != FrameBuffer::Format::kNV21) {
  4606. @@ -884,8 +892,12 @@ absl::Status FlipPlaneVertically(const FrameBuffer& buffer,
  4607. }
  4608. // This method only supports kGRAY, kRGBA, and kRGB formats.
  4609. -absl::Status CropPlane(const FrameBuffer& buffer, int x0, int y0, int x1,
  4610. - int y1, FrameBuffer* output_buffer) {
  4611. +absl::Status CropPlane(const FrameBuffer& buffer,
  4612. + int x0,
  4613. + int y0,
  4614. + int x1,
  4615. + int y1,
  4616. + FrameBuffer* output_buffer) {
  4617. if (buffer.plane_count() > 1) {
  4618. return CreateStatusWithPayload(
  4619. StatusCode::kInternal,
  4620. @@ -912,7 +924,11 @@ absl::Status CropPlane(const FrameBuffer& buffer, int x0, int y0, int x1,
  4621. // Crops NV12/NV21 FrameBuffer to the subregion defined by the top left pixel
  4622. // position (x0, y0) and the bottom right pixel position (x1, y1).
  4623. -absl::Status CropNv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
  4624. +absl::Status CropNv(const FrameBuffer& buffer,
  4625. + int x0,
  4626. + int y0,
  4627. + int x1,
  4628. + int y1,
  4629. FrameBuffer* output_buffer) {
  4630. ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
  4631. FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
  4632. @@ -944,7 +960,11 @@ absl::Status CropNv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
  4633. // Crops YV12/YV21 FrameBuffer to the subregion defined by the top left pixel
  4634. // position (x0, y0) and the bottom right pixel position (x1, y1).
  4635. -absl::Status CropYv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
  4636. +absl::Status CropYv(const FrameBuffer& buffer,
  4637. + int x0,
  4638. + int y0,
  4639. + int x1,
  4640. + int y1,
  4641. FrameBuffer* output_buffer) {
  4642. ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
  4643. FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
  4644. @@ -979,8 +999,12 @@ absl::Status CropYv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
  4645. return absl::OkStatus();
  4646. }
  4647. -absl::Status CropResizeYuv(const FrameBuffer& buffer, int x0, int y0, int x1,
  4648. - int y1, FrameBuffer* output_buffer) {
  4649. +absl::Status CropResizeYuv(const FrameBuffer& buffer,
  4650. + int x0,
  4651. + int y0,
  4652. + int x1,
  4653. + int y1,
  4654. + FrameBuffer* output_buffer) {
  4655. FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1);
  4656. if (crop_dimension == output_buffer->dimension()) {
  4657. switch (buffer.format()) {
  4658. @@ -1308,8 +1332,12 @@ absl::Status ResizeGray(const FrameBuffer& buffer, FrameBuffer* output_buffer) {
  4659. }
  4660. // This method only supports kGRAY, kRGBA, and kRGB formats.
  4661. -absl::Status CropResize(const FrameBuffer& buffer, int x0, int y0, int x1,
  4662. - int y1, FrameBuffer* output_buffer) {
  4663. +absl::Status CropResize(const FrameBuffer& buffer,
  4664. + int x0,
  4665. + int y0,
  4666. + int x1,
  4667. + int y1,
  4668. + FrameBuffer* output_buffer) {
  4669. FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1);
  4670. if (crop_dimension == output_buffer->dimension()) {
  4671. return CropPlane(buffer, x0, y0, x1, y1, output_buffer);
  4672. @@ -1343,8 +1371,11 @@ absl::Status CropResize(const FrameBuffer& buffer, int x0, int y0, int x1,
  4673. } // namespace
  4674. -absl::Status LibyuvFrameBufferUtils::Crop(const FrameBuffer& buffer, int x0,
  4675. - int y0, int x1, int y1,
  4676. +absl::Status LibyuvFrameBufferUtils::Crop(const FrameBuffer& buffer,
  4677. + int x0,
  4678. + int y0,
  4679. + int x1,
  4680. + int y1,
  4681. FrameBuffer* output_buffer) {
  4682. RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
  4683. RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
  4684. @@ -1425,7 +1456,8 @@ absl::Status LibyuvFrameBufferUtils::Rotate(const FrameBuffer& buffer,
  4685. }
  4686. absl::Status LibyuvFrameBufferUtils::FlipHorizontally(
  4687. - const FrameBuffer& buffer, FrameBuffer* output_buffer) {
  4688. + const FrameBuffer& buffer,
  4689. + FrameBuffer* output_buffer) {
  4690. RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
  4691. RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
  4692. RETURN_IF_ERROR(ValidateFlipBufferInputs(buffer, *output_buffer));
  4693. @@ -1453,7 +1485,8 @@ absl::Status LibyuvFrameBufferUtils::FlipHorizontally(
  4694. }
  4695. absl::Status LibyuvFrameBufferUtils::FlipVertically(
  4696. - const FrameBuffer& buffer, FrameBuffer* output_buffer) {
  4697. + const FrameBuffer& buffer,
  4698. + FrameBuffer* output_buffer) {
  4699. RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
  4700. RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
  4701. RETURN_IF_ERROR(ValidateFlipBufferInputs(buffer, *output_buffer));
  4702. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h
  4703. index 5da898bc058a4..6f83559139130 100644
  4704. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h
  4705. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h
  4706. @@ -41,7 +41,11 @@ class LibyuvFrameBufferUtils : public FrameBufferUtilsInterface {
  4707. //
  4708. // Crop region dimensions must be equal or smaller than input `buffer`
  4709. // dimensions.
  4710. - absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
  4711. + absl::Status Crop(const FrameBuffer& buffer,
  4712. + int x0,
  4713. + int y0,
  4714. + int x1,
  4715. + int y1,
  4716. FrameBuffer* output_buffer) override;
  4717. // Resizes `buffer` to the size of the given `output_buffer`.
  4718. @@ -51,7 +55,8 @@ class LibyuvFrameBufferUtils : public FrameBufferUtilsInterface {
  4719. // Rotates `buffer` counter-clockwise by the given `angle_deg` (in degrees).
  4720. //
  4721. // The given angle must be a multiple of 90 degrees.
  4722. - absl::Status Rotate(const FrameBuffer& buffer, int angle_deg,
  4723. + absl::Status Rotate(const FrameBuffer& buffer,
  4724. + int angle_deg,
  4725. FrameBuffer* output_buffer) override;
  4726. // Flips `buffer` horizontally.
  4727. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc
  4728. index bc57c0b904534..d58969d96827e 100644
  4729. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc
  4730. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc
  4731. @@ -20,11 +20,11 @@ limitations under the License.
  4732. #include <utility>
  4733. #include <vector>
  4734. -#include "absl/status/status.h" // from @com_google_absl
  4735. -#include "absl/strings/str_format.h" // from @com_google_absl
  4736. -#include "absl/strings/str_split.h" // from @com_google_absl
  4737. +#include "absl/status/status.h" // from @com_google_absl
  4738. +#include "absl/strings/str_format.h" // from @com_google_absl
  4739. +#include "absl/strings/str_split.h" // from @com_google_absl
  4740. #include "absl/strings/string_view.h" // from @com_google_absl
  4741. -#include "absl/types/optional.h" // from @com_google_absl
  4742. +#include "absl/types/optional.h" // from @com_google_absl
  4743. #include "tensorflow_lite_support/cc/common.h"
  4744. #include "tensorflow_lite_support/cc/port/status_macros.h"
  4745. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h
  4746. index 95cbecf54bd1d..e2b403d9b35b9 100644
  4747. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h
  4748. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h
  4749. @@ -23,9 +23,9 @@ limitations under the License.
  4750. #include <vector>
  4751. #include "absl/container/flat_hash_map.h" // from @com_google_absl
  4752. -#include "absl/status/status.h" // from @com_google_absl
  4753. -#include "absl/strings/string_view.h" // from @com_google_absl
  4754. -#include "absl/types/optional.h" // from @com_google_absl
  4755. +#include "absl/status/status.h" // from @com_google_absl
  4756. +#include "absl/strings/string_view.h" // from @com_google_absl
  4757. +#include "absl/types/optional.h" // from @com_google_absl
  4758. #include "tensorflow_lite_support/cc/port/statusor.h"
  4759. #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
  4760. #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
  4761. @@ -37,7 +37,10 @@ namespace vision {
  4762. // Sigmoid structure.
  4763. struct Sigmoid {
  4764. Sigmoid() : scale(1.0) {}
  4765. - Sigmoid(std::string label, float slope, float offset, float scale = 1.0,
  4766. + Sigmoid(std::string label,
  4767. + float slope,
  4768. + float offset,
  4769. + float scale = 1.0,
  4770. absl::optional<float> min_uncalibrated_score = absl::nullopt)
  4771. : label(label),
  4772. slope(slope),
  4773. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc
  4774. index 311994c1abbf9..bc2f9dfd53a96 100644
  4775. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc
  4776. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc
  4777. @@ -16,7 +16,7 @@ limitations under the License.
  4778. #include "tensorflow_lite_support/cc/common.h"
  4779. #include "absl/status/status.h" // from @com_google_absl
  4780. -#include "absl/strings/cord.h" // from @com_google_absl
  4781. +#include "absl/strings/cord.h" // from @com_google_absl
  4782. #include "tensorflow_lite_support/cc/port/gmock.h"
  4783. #include "tensorflow_lite_support/cc/port/gtest.h"
  4784. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc
  4785. index 9a00e2f9e89a1..ef0e783e97c3e 100644
  4786. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc
  4787. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc
  4788. @@ -46,8 +46,8 @@ constexpr char kTestDataDirectory[] =
  4789. constexpr char kDilatedConvolutionModelWithMetaData[] = "dilated_conv.tflite";
  4790. StatusOr<ImageData> LoadImage(std::string image_name) {
  4791. - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
  4792. - kTestDataDirectory, image_name));
  4793. + return DecodeImageFromFile(
  4794. + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
  4795. }
  4796. class DynamicInputTest : public tflite_shims::testing::Test {
  4797. @@ -60,7 +60,7 @@ class DynamicInputTest : public tflite_shims::testing::Test {
  4798. SUPPORT_ASSERT_OK(engine_->InitInterpreter());
  4799. SUPPORT_ASSERT_OK_AND_ASSIGN(auto preprocessor,
  4800. - ImagePreprocessor::Create(engine_.get(), {0}));
  4801. + ImagePreprocessor::Create(engine_.get(), {0}));
  4802. SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
  4803. std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
  4804. @@ -94,9 +94,10 @@ TEST_F(DynamicInputTest, GoldenImageComparison) {
  4805. PreprocessImage();
  4806. // Get the processed input image.
  4807. - SUPPORT_ASSERT_OK_AND_ASSIGN(float* processed_input_data,
  4808. - tflite::task::core::AssertAndReturnTypedTensor<float>(
  4809. - engine_->GetInputs()[0]));
  4810. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  4811. + float* processed_input_data,
  4812. + tflite::task::core::AssertAndReturnTypedTensor<float>(
  4813. + engine_->GetInputs()[0]));
  4814. SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
  4815. const uint8* image_data = image.pixel_data;
  4816. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
  4817. index 629f069e7b8d1..c4a8cea0d53b9 100644
  4818. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
  4819. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
  4820. @@ -49,8 +49,7 @@ constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";
  4821. constexpr int kMaxSeqLen = 128;
  4822. std::string GetFullPath(absl::string_view file_name) {
  4823. - return JoinPath("./" /*test src dir*/, kTestDataDirectory,
  4824. - file_name);
  4825. + return JoinPath("./" /*test src dir*/, kTestDataDirectory, file_name);
  4826. }
  4827. class BertNLClassifierTest : public tflite_shims::testing::Test {};
  4828. @@ -77,14 +76,15 @@ TEST_F(BertNLClassifierTest, CreateFromOptionsFailsWithMissingBaseOptions) {
  4829. }
  4830. TEST_F(BertNLClassifierTest, TestNLClassifierCreationFilePath) {
  4831. - SUPPORT_ASSERT_OK(BertNLClassifier::CreateFromFile(GetFullPath(kTestModelPath)));
  4832. + SUPPORT_ASSERT_OK(
  4833. + BertNLClassifier::CreateFromFile(GetFullPath(kTestModelPath)));
  4834. }
  4835. TEST_F(BertNLClassifierTest, TestNLClassifierCreationBinary) {
  4836. std::string model_buffer =
  4837. LoadBinaryContent(GetFullPath(kTestModelPath).c_str());
  4838. SUPPORT_ASSERT_OK(BertNLClassifier::CreateFromBuffer(model_buffer.data(),
  4839. - model_buffer.size()));
  4840. + model_buffer.size()));
  4841. }
  4842. TEST_F(BertNLClassifierTest, TestNLClassifierCreationFailure) {
  4843. @@ -136,7 +136,7 @@ TEST_F(BertNLClassifierTest, ClassifySucceedsWithBaseOptions) {
  4844. contents);
  4845. SUPPORT_ASSERT_OK_AND_ASSIGN(classifier,
  4846. - BertNLClassifier::CreateFromOptions(options));
  4847. + BertNLClassifier::CreateFromOptions(options));
  4848. }
  4849. verify_classifier(std::move(classifier), /*verify_positive=*/false);
  4850. @@ -146,8 +146,8 @@ TEST_F(BertNLClassifierTest, TestNLClassifier_ClassifyNegative) {
  4851. std::string model_buffer =
  4852. LoadBinaryContent(GetFullPath(kTestModelPath).c_str());
  4853. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier,
  4854. - BertNLClassifier::CreateFromBuffer(model_buffer.data(),
  4855. - model_buffer.size()));
  4856. + BertNLClassifier::CreateFromBuffer(
  4857. + model_buffer.data(), model_buffer.size()));
  4858. verify_classifier(std::move(classifier), false);
  4859. }
  4860. @@ -156,24 +156,26 @@ TEST_F(BertNLClassifierTest, TestNLClassifier_ClassifyPositive) {
  4861. std::string model_buffer =
  4862. LoadBinaryContent(GetFullPath(kTestModelPath).c_str());
  4863. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier,
  4864. - BertNLClassifier::CreateFromBuffer(model_buffer.data(),
  4865. - model_buffer.size()));
  4866. + BertNLClassifier::CreateFromBuffer(
  4867. + model_buffer.data(), model_buffer.size()));
  4868. verify_classifier(std::move(classifier), true);
  4869. }
  4870. TEST_F(BertNLClassifierTest, TestNLClassifierFd_ClassifyPositive) {
  4871. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier,
  4872. - BertNLClassifier::CreateFromFd(open(
  4873. - GetFullPath(kTestModelPath).c_str(), O_RDONLY)));
  4874. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  4875. + std::unique_ptr<BertNLClassifier> classifier,
  4876. + BertNLClassifier::CreateFromFd(
  4877. + open(GetFullPath(kTestModelPath).c_str(), O_RDONLY)));
  4878. verify_classifier(std::move(classifier), false);
  4879. }
  4880. TEST_F(BertNLClassifierTest, TestNLClassifierFd_ClassifyNegative) {
  4881. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier,
  4882. - BertNLClassifier::CreateFromFd(open(
  4883. - GetFullPath(kTestModelPath).c_str(), O_RDONLY)));
  4884. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  4885. + std::unique_ptr<BertNLClassifier> classifier,
  4886. + BertNLClassifier::CreateFromFd(
  4887. + open(GetFullPath(kTestModelPath).c_str(), O_RDONLY)));
  4888. verify_classifier(std::move(classifier), true);
  4889. }
  4890. @@ -191,8 +193,8 @@ TEST_F(BertNLClassifierTest, TestNLClassifier_ClassifyLongPositive_notOOB) {
  4891. }
  4892. ss_for_positive_review << " movie review";
  4893. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier,
  4894. - BertNLClassifier::CreateFromBuffer(model_buffer.data(),
  4895. - model_buffer.size()));
  4896. + BertNLClassifier::CreateFromBuffer(
  4897. + model_buffer.data(), model_buffer.size()));
  4898. std::vector<core::Category> results =
  4899. classifier->Classify(ss_for_positive_review.str());
  4900. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc
  4901. index 252441df1cb59..a70dab7782044 100644
  4902. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc
  4903. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc
  4904. @@ -69,8 +69,7 @@ constexpr int kPredictAnsNum = 5;
  4905. class BertQuestionAnswererTest : public tflite_shims::testing::Test {};
  4906. std::string GetFullPath(absl::string_view file_name) {
  4907. - return JoinPath("./" /*test src dir*/, kTestDataDirectory,
  4908. - file_name);
  4909. + return JoinPath("./" /*test src dir*/, kTestDataDirectory, file_name);
  4910. }
  4911. TEST_F(BertQuestionAnswererTest,
  4912. @@ -108,8 +107,8 @@ TEST_F(BertQuestionAnswererTest, AnswerSucceedsWithModelWithMetadata) {
  4913. options.mutable_base_options()->mutable_model_file()->set_file_content(
  4914. contents);
  4915. - SUPPORT_ASSERT_OK_AND_ASSIGN(question_answerer,
  4916. - BertQuestionAnswerer::CreateFromOptions(options));
  4917. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  4918. + question_answerer, BertQuestionAnswerer::CreateFromOptions(options));
  4919. }
  4920. std::vector<QaAnswer> answer = question_answerer->Answer(kContext, kQuestion);
  4921. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/clu_lib/bert_utils_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/clu_lib/bert_utils_test.cc
  4922. index 3d98fe16b07e9..6fd9508fd1ba0 100644
  4923. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/clu_lib/bert_utils_test.cc
  4924. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/clu_lib/bert_utils_test.cc
  4925. @@ -128,10 +128,10 @@ TEST_F(BertUtilsTestClass, ZeroHistoryNotTrucated) {
  4926. std::vector<int> subword_indicators;
  4927. std::vector<int> segment_id_list;
  4928. std::vector<int> turn_id_list;
  4929. - SUPPORT_ASSERT_OK(BertPreprocessing(tokenizer_.get(), conversations_in_reverse_order,
  4930. - max_seq_length, max_history_turns, &token_ids,
  4931. - &token_alignments, &subword_indicators,
  4932. - &segment_id_list, &turn_id_list));
  4933. + SUPPORT_ASSERT_OK(BertPreprocessing(
  4934. + tokenizer_.get(), conversations_in_reverse_order, max_seq_length,
  4935. + max_history_turns, &token_ids, &token_alignments, &subword_indicators,
  4936. + &segment_id_list, &turn_id_list));
  4937. EXPECT_THAT(token_ids, expected_token_ids);
  4938. EXPECT_THAT(token_alignments, expected_token_alignments);
  4939. EXPECT_THAT(subword_indicators, expected_first_subword_indicators);
  4940. @@ -193,10 +193,10 @@ TEST_F(BertUtilsTestClass, ZeroHistoryTrucated) {
  4941. std::vector<int> subword_indicators;
  4942. std::vector<int> segment_id_list;
  4943. std::vector<int> turn_id_list;
  4944. - SUPPORT_ASSERT_OK(BertPreprocessing(tokenizer_.get(), conversations_in_reverse_order,
  4945. - max_seq_length, max_history_turns, &token_ids,
  4946. - &token_alignments, &subword_indicators,
  4947. - &segment_id_list, &turn_id_list));
  4948. + SUPPORT_ASSERT_OK(BertPreprocessing(
  4949. + tokenizer_.get(), conversations_in_reverse_order, max_seq_length,
  4950. + max_history_turns, &token_ids, &token_alignments, &subword_indicators,
  4951. + &segment_id_list, &turn_id_list));
  4952. EXPECT_THAT(token_ids, expected_token_ids);
  4953. EXPECT_THAT(token_alignments, expected_token_alignments);
  4954. EXPECT_THAT(subword_indicators, expected_first_subword_indicators);
  4955. @@ -342,10 +342,10 @@ TEST_F(BertUtilsTestClass, WithHistoryNotTrucated) {
  4956. std::vector<int> subword_indicators;
  4957. std::vector<int> segment_id_list;
  4958. std::vector<int> turn_id_list;
  4959. - SUPPORT_ASSERT_OK(BertPreprocessing(tokenizer_.get(), conversations_in_reverse_order,
  4960. - max_seq_length, max_history_turns, &token_ids,
  4961. - &token_alignments, &subword_indicators,
  4962. - &segment_id_list, &turn_id_list));
  4963. + SUPPORT_ASSERT_OK(BertPreprocessing(
  4964. + tokenizer_.get(), conversations_in_reverse_order, max_seq_length,
  4965. + max_history_turns, &token_ids, &token_alignments, &subword_indicators,
  4966. + &segment_id_list, &turn_id_list));
  4967. EXPECT_THAT(token_ids, expected_token_ids);
  4968. EXPECT_THAT(token_alignments, expected_token_alignments);
  4969. EXPECT_THAT(subword_indicators, expected_first_subword_indicators);
  4970. @@ -458,10 +458,10 @@ TEST_F(BertUtilsTestClass, WithHistoryTrucated) {
  4971. std::vector<int> subword_indicators;
  4972. std::vector<int> segment_id_list;
  4973. std::vector<int> turn_id_list;
  4974. - SUPPORT_ASSERT_OK(BertPreprocessing(tokenizer_.get(), conversations_in_reverse_order,
  4975. - max_seq_length, max_history_turns, &token_ids,
  4976. - &token_alignments, &subword_indicators,
  4977. - &segment_id_list, &turn_id_list));
  4978. + SUPPORT_ASSERT_OK(BertPreprocessing(
  4979. + tokenizer_.get(), conversations_in_reverse_order, max_seq_length,
  4980. + max_history_turns, &token_ids, &token_alignments, &subword_indicators,
  4981. + &segment_id_list, &turn_id_list));
  4982. EXPECT_THAT(token_ids, expected_token_ids);
  4983. EXPECT_THAT(token_alignments, expected_token_alignments);
  4984. EXPECT_THAT(subword_indicators, expected_first_subword_indicators);
  4985. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/clu_lib/intent_repr_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/clu_lib/intent_repr_test.cc
  4986. index 8341751bbbac2..0501ec4a669b5 100644
  4987. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/clu_lib/intent_repr_test.cc
  4988. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/clu_lib/intent_repr_test.cc
  4989. @@ -29,7 +29,7 @@ TEST(IntentClassification, IntentRepr) {
  4990. TEST(IntentClassification, IntentRepr2) {
  4991. SUPPORT_ASSERT_OK_AND_ASSIGN(const auto intent_repr,
  4992. - IntentRepr::CreateFromFullName("REQUEST"));
  4993. + IntentRepr::CreateFromFullName("REQUEST"));
  4994. EXPECT_EQ(intent_repr.Name(), "REQUEST");
  4995. EXPECT_EQ(intent_repr.Domain(), "");
  4996. }
  4997. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc
  4998. index 67b03c3a45323..81198cfca30fc 100644
  4999. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc
  5000. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc
  5001. @@ -121,8 +121,7 @@ struct ProtoOptionsTestParam {
  5002. };
  5003. std::string GetFullPath(absl::string_view file_name) {
  5004. - return JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5005. - file_name);
  5006. + return JoinPath("./" /*test src dir*/, kTestDataDirectory, file_name);
  5007. }
  5008. class ProtoOptionsTest : public TestWithParam<ProtoOptionsTestParam> {
  5009. @@ -163,7 +162,8 @@ TEST_F(ProtoOptionsTest, ClassifySucceedsWithBaseOptions) {
  5010. options.mutable_base_options()->mutable_model_file()->set_file_content(
  5011. contents);
  5012. - SUPPORT_ASSERT_OK_AND_ASSIGN(classifier, NLClassifier::CreateFromOptions(options));
  5013. + SUPPORT_ASSERT_OK_AND_ASSIGN(classifier,
  5014. + NLClassifier::CreateFromOptions(options));
  5015. }
  5016. std::vector<core::Category> positive_results =
  5017. @@ -180,8 +180,8 @@ TEST_F(ProtoOptionsTest, ClassifySucceedsWithBaseOptions) {
  5018. TEST_F(ProtoOptionsTest, CreationFromIncorrectInputTensor) {
  5019. NLClassifierProtoOptions options;
  5020. - options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  5021. - "./" /*test src dir*/, kTestDataDirectory, kTestModelPath));
  5022. + options.mutable_base_options()->mutable_model_file()->set_file_name(
  5023. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath));
  5024. options.set_input_tensor_name("invalid_tensor_name");
  5025. options.set_input_tensor_index(-1);
  5026. @@ -200,8 +200,8 @@ TEST_F(ProtoOptionsTest, CreationFromIncorrectInputTensor) {
  5027. TEST_F(ProtoOptionsTest, CreationFromIncorrectOutputScoreTensor) {
  5028. NLClassifierProtoOptions options;
  5029. - options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  5030. - "./" /*test src dir*/, kTestDataDirectory, kTestModelPath));
  5031. + options.mutable_base_options()->mutable_model_file()->set_file_name(
  5032. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath));
  5033. options.set_output_score_tensor_name("invalid_tensor_name");
  5034. options.set_output_score_tensor_index(-1);
  5035. @@ -224,7 +224,7 @@ TEST_F(ProtoOptionsTest, TestInferenceWithRegexTokenizer) {
  5036. options.mutable_base_options()->mutable_model_file()->set_file_name(
  5037. GetFullPath(kTestModelWithRegexTokenizer));
  5038. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier,
  5039. - NLClassifier::CreateFromOptions(options));
  5040. + NLClassifier::CreateFromOptions(options));
  5041. std::vector<core::Category> positive_results =
  5042. classifier->Classify(kPositiveInput);
  5043. @@ -277,7 +277,7 @@ TEST_F(ProtoOptionsTest, TestInferenceWithAssociatedLabelBuiltinOps) {
  5044. options.mutable_base_options()->mutable_model_file()->set_file_name(
  5045. GetFullPath(kTestModelWithLabelBuiltInOpsPath));
  5046. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier,
  5047. - NLClassifier::CreateFromOptions(options));
  5048. + NLClassifier::CreateFromOptions(options));
  5049. std::vector<core::Category> results = classifier->Classify(kInputStr);
  5050. std::vector<core::Category> expected_class = {
  5051. {"Negative", 0.49332118034362793},
  5052. @@ -296,8 +296,10 @@ struct ProtoOptionsTestParamToString {
  5053. };
  5054. NLClassifierProtoOptions CreateProtoOptionsFromTensorName(
  5055. - const char* input_tensor_name, const char* output_score_tensor_name,
  5056. - const char* output_label_tensor_name, const char* model_path) {
  5057. + const char* input_tensor_name,
  5058. + const char* output_score_tensor_name,
  5059. + const char* output_label_tensor_name,
  5060. + const char* model_path) {
  5061. NLClassifierProtoOptions options;
  5062. options.set_input_tensor_name(input_tensor_name);
  5063. options.set_output_score_tensor_name(output_score_tensor_name);
  5064. @@ -310,8 +312,10 @@ NLClassifierProtoOptions CreateProtoOptionsFromTensorName(
  5065. }
  5066. NLClassifierProtoOptions CreateProtoOptionsFromTensorIndex(
  5067. - const int input_tensor_index, const int output_score_tensor_index,
  5068. - const int output_label_tensor_index, const char* model_path) {
  5069. + const int input_tensor_index,
  5070. + const int output_score_tensor_index,
  5071. + const int output_label_tensor_index,
  5072. + const char* model_path) {
  5073. NLClassifierProtoOptions options;
  5074. options.set_input_tensor_index(input_tensor_index);
  5075. options.set_output_score_tensor_index(output_score_tensor_index);
  5076. @@ -439,14 +443,16 @@ TEST_P(ProtoOptionsTest, TestClassify) {
  5077. EXPECT_THAT(results, UnorderedElementsAreArray(expected_class));
  5078. }
  5079. -INSTANTIATE_TEST_SUITE_P(TestClassify, ProtoOptionsTest,
  5080. +INSTANTIATE_TEST_SUITE_P(TestClassify,
  5081. + ProtoOptionsTest,
  5082. ValuesIn(ClassifyParams()),
  5083. ProtoOptionsTestParamToString());
  5084. // Tests for struct sNLClassifierOptions.
  5085. class StructOptionsTest : public tflite_shims::testing::Test {};
  5086. -void AssertStatus(absl::Status status, absl::StatusCode status_code,
  5087. +void AssertStatus(absl::Status status,
  5088. + absl::StatusCode status_code,
  5089. TfLiteSupportStatus tfls_code) {
  5090. ASSERT_EQ(status.code(), status_code);
  5091. EXPECT_THAT(status.GetPayload(kTfLiteSupportPayload),
  5092. @@ -454,30 +460,29 @@ void AssertStatus(absl::Status status, absl::StatusCode status_code,
  5093. }
  5094. TEST_F(StructOptionsTest, TestApiCreationFromBuffer) {
  5095. - std::string model_buffer =
  5096. - LoadBinaryContent(JoinPath("./" /*test src dir*/,
  5097. - kTestDataDirectory, kTestModelPath)
  5098. - .c_str());
  5099. + std::string model_buffer = LoadBinaryContent(
  5100. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath)
  5101. + .c_str());
  5102. SUPPORT_ASSERT_OK(NLClassifier::CreateFromBufferAndOptions(
  5103. model_buffer.data(), model_buffer.size(), {}, CreateCustomResolver()));
  5104. }
  5105. TEST_F(StructOptionsTest, TestApiCreationFromFile) {
  5106. - SUPPORT_ASSERT_OK(NLClassifier::CreateFromFileAndOptions(GetFullPath(kTestModelPath),
  5107. - {}, CreateCustomResolver()));
  5108. + SUPPORT_ASSERT_OK(NLClassifier::CreateFromFileAndOptions(
  5109. + GetFullPath(kTestModelPath), {}, CreateCustomResolver()));
  5110. }
  5111. TEST_F(StructOptionsTest, TestApiCreationFromIncorrectInputTensor) {
  5112. NLClassifierOptions options;
  5113. options.input_tensor_index = -1;
  5114. options.input_tensor_name = "I do not exist";
  5115. - AssertStatus(NLClassifier::CreateFromFileAndOptions(
  5116. - JoinPath("./" /*test src dir*/,
  5117. - kTestDataDirectory, kTestModelPath),
  5118. - options, CreateCustomResolver())
  5119. - .status(),
  5120. - absl::StatusCode::kInvalidArgument,
  5121. - TfLiteSupportStatus::kInputTensorNotFoundError);
  5122. + AssertStatus(
  5123. + NLClassifier::CreateFromFileAndOptions(
  5124. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath),
  5125. + options, CreateCustomResolver())
  5126. + .status(),
  5127. + absl::StatusCode::kInvalidArgument,
  5128. + TfLiteSupportStatus::kInputTensorNotFoundError);
  5129. }
  5130. TEST_F(StructOptionsTest, TestApiCreationFromIncorrectOutputScoreTensor) {
  5131. @@ -497,9 +502,10 @@ TEST_F(StructOptionsTest, TestInferenceWithRegexTokenizer) {
  5132. options.output_score_tensor_name = "probability";
  5133. // The model with regex tokenizer doesn't need any custom ops.
  5134. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier,
  5135. - NLClassifier::CreateFromFileAndOptions(
  5136. - GetFullPath(kTestModelWithRegexTokenizer), options));
  5137. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5138. + std::unique_ptr<NLClassifier> classifier,
  5139. + NLClassifier::CreateFromFileAndOptions(
  5140. + GetFullPath(kTestModelWithRegexTokenizer), options));
  5141. std::vector<core::Category> positive_results =
  5142. classifier->Classify(kPositiveInput);
  5143. @@ -519,9 +525,9 @@ TEST_F(StructOptionsTest, TestInferenceWithBoolOutput) {
  5144. options.output_score_tensor_index = 0;
  5145. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier,
  5146. - NLClassifier::CreateFromFileAndOptions(
  5147. - GetFullPath(kTestModelBoolOutputPath), options,
  5148. - CreateCustomResolver()));
  5149. + NLClassifier::CreateFromFileAndOptions(
  5150. + GetFullPath(kTestModelBoolOutputPath),
  5151. + options, CreateCustomResolver()));
  5152. std::vector<core::Category> results = classifier->Classify(kInputStr);
  5153. std::vector<core::Category> expected_class = {
  5154. {"0", 1},
  5155. @@ -535,10 +541,11 @@ TEST_F(StructOptionsTest, TestInferenceWithBoolOutput) {
  5156. TEST_F(StructOptionsTest, TestInferenceWithAssociatedLabelCustomOps) {
  5157. NLClassifierOptions options;
  5158. options.output_score_tensor_name = kMetadataOutputScoreTensorName;
  5159. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier,
  5160. - NLClassifier::CreateFromFileAndOptions(
  5161. - GetFullPath(kTestModelWithLabelCustomOpsPath),
  5162. - options, CreateCustomResolver()));
  5163. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5164. + std::unique_ptr<NLClassifier> classifier,
  5165. + NLClassifier::CreateFromFileAndOptions(
  5166. + GetFullPath(kTestModelWithLabelCustomOpsPath), options,
  5167. + CreateCustomResolver()));
  5168. std::vector<core::Category> results = classifier->Classify(kInputStr);
  5169. std::vector<core::Category> expected_class = {
  5170. {"label0", 255},
  5171. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_embedder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_embedder_test.cc
  5172. index 5a86a288b4624..b097813ecedf7 100644
  5173. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_embedder_test.cc
  5174. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_embedder_test.cc
  5175. @@ -17,7 +17,7 @@ limitations under the License.
  5176. #include <iostream>
  5177. -#include "absl/status/status.h" // from @com_google_absl
  5178. +#include "absl/status/status.h" // from @com_google_absl
  5179. #include "absl/strings/string_view.h" // from @com_google_absl
  5180. #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
  5181. #include "tensorflow_lite_support/cc/port/gmock.h"
  5182. @@ -56,8 +56,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {};
  5183. TextEmbedderOptions GetBasicOptions(absl::string_view model_name) {
  5184. TextEmbedderOptions options;
  5185. - options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  5186. - "./" /*test src dir*/, kTestDataDirectory, model_name));
  5187. + options.mutable_base_options()->mutable_model_file()->set_file_name(
  5188. + JoinPath("./" /*test src dir*/, kTestDataDirectory, model_name));
  5189. return options;
  5190. }
  5191. @@ -130,7 +130,7 @@ TEST(EmbedTest, SucceedsWithMobileBertModel) {
  5192. TextEmbedderOptions options = GetBasicOptions(kMobileBert);
  5193. // No Embedding options means all head get a default option.
  5194. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
  5195. - TextEmbedder::CreateFromOptions(options));
  5196. + TextEmbedder::CreateFromOptions(options));
  5197. SUPPORT_ASSERT_OK_AND_ASSIGN(
  5198. auto result0,
  5199. @@ -141,8 +141,8 @@ TEST(EmbedTest, SucceedsWithMobileBertModel) {
  5200. EXPECT_NEAR(result0.embeddings(0).feature_vector().value_float(0), 19.9016f,
  5201. kValueDiffTolerance);
  5202. - SUPPORT_ASSERT_OK_AND_ASSIGN(auto result1,
  5203. - text_embedder->Embed("what a great and fantastic trip"));
  5204. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5205. + auto result1, text_embedder->Embed("what a great and fantastic trip"));
  5206. EXPECT_EQ(result1.embeddings_size(), 1);
  5207. EXPECT_EQ(result1.embeddings(0).feature_vector().value_float_size(), 512);
  5208. @@ -162,7 +162,7 @@ TEST(EmbedTest, SucceedsWithRegexModel) {
  5209. TextEmbedderOptions options = GetBasicOptions(kRegexOneEmbeddingModel);
  5210. // No Embedding options means all head get a default option.
  5211. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
  5212. - TextEmbedder::CreateFromOptions(options));
  5213. + TextEmbedder::CreateFromOptions(options));
  5214. SUPPORT_ASSERT_OK_AND_ASSIGN(
  5215. auto result0,
  5216. @@ -173,8 +173,8 @@ TEST(EmbedTest, SucceedsWithRegexModel) {
  5217. EXPECT_NEAR(result0.embeddings(0).feature_vector().value_float(0), 0.0309356f,
  5218. kValueDiffTolerance);
  5219. - SUPPORT_ASSERT_OK_AND_ASSIGN(auto result1,
  5220. - text_embedder->Embed("what a great and fantastic trip"));
  5221. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5222. + auto result1, text_embedder->Embed("what a great and fantastic trip"));
  5223. EXPECT_EQ(result1.embeddings_size(), 1);
  5224. EXPECT_EQ(result1.embeddings(0).feature_vector().value_float_size(), 16);
  5225. @@ -206,8 +206,8 @@ TEST(EmbedTest, SucceedsWithUniversalSentenceEncoder) {
  5226. EXPECT_NEAR(result0.embeddings(0).feature_vector().value_float(0), 1.422951f,
  5227. kValueDiffTolerance);
  5228. - SUPPORT_ASSERT_OK_AND_ASSIGN(auto result1,
  5229. - text_embedder->Embed("what a great and fantastic trip"));
  5230. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5231. + auto result1, text_embedder->Embed("what a great and fantastic trip"));
  5232. EXPECT_EQ(result1.embeddings_size(), 1);
  5233. EXPECT_EQ(result1.embeddings(0).feature_vector().value_float_size(), 100);
  5234. @@ -227,7 +227,7 @@ TEST(GetEmbeddingDimension, Succeeds) {
  5235. // Create embedder.
  5236. TextEmbedderOptions options = GetBasicOptions(kMobileBert);
  5237. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
  5238. - TextEmbedder::CreateFromOptions(options));
  5239. + TextEmbedder::CreateFromOptions(options));
  5240. EXPECT_EQ(text_embedder->GetEmbeddingDimension(0), 512);
  5241. EXPECT_EQ(text_embedder->GetEmbeddingDimension(1), -1);
  5242. @@ -238,7 +238,7 @@ TEST(GetNumberOfOutputLayers, Succeeds) {
  5243. TextEmbedderOptions options = GetBasicOptions(kMobileBert);
  5244. // No Embedding options means all head get a default option.
  5245. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
  5246. - TextEmbedder::CreateFromOptions(options));
  5247. + TextEmbedder::CreateFromOptions(options));
  5248. EXPECT_EQ(text_embedder->GetNumberOfOutputLayers(), kNumberOfOutputLayers);
  5249. }
  5250. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_searcher_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_searcher_test.cc
  5251. index fec09a1ad77cc..f38615c5b3092 100644
  5252. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_searcher_test.cc
  5253. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_searcher_test.cc
  5254. @@ -18,9 +18,9 @@ limitations under the License.
  5255. #include <memory>
  5256. #include <string>
  5257. -#include "absl/flags/flag.h" // from @com_google_absl
  5258. -#include "absl/status/status.h" // from @com_google_absl
  5259. -#include "absl/strings/cord.h" // from @com_google_absl
  5260. +#include "absl/flags/flag.h" // from @com_google_absl
  5261. +#include "absl/status/status.h" // from @com_google_absl
  5262. +#include "absl/strings/cord.h" // from @com_google_absl
  5263. #include "absl/strings/str_cat.h" // from @com_google_absl
  5264. #include "tensorflow/lite/core/api/op_resolver.h"
  5265. #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
  5266. @@ -219,7 +219,8 @@ TEST_P(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
  5267. }
  5268. INSTANTIATE_TEST_SUITE_P(
  5269. - CreateFromOptionsTest, CreateFromOptionsTest,
  5270. + CreateFromOptionsTest,
  5271. + CreateFromOptionsTest,
  5272. Values(CreateFromOptionsParams{.name = "Bert",
  5273. .embedder_model_name = kMobileBertEmbedder,
  5274. .searcher_model_name = kMobileBertSearcher,
  5275. @@ -267,7 +268,7 @@ TEST_P(SearchTest, SucceedsWithStandaloneIndex) {
  5276. // Perform search.
  5277. SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result,
  5278. - searcher->Search("The weather was excellent."));
  5279. + searcher->Search("The weather was excellent."));
  5280. // Check results.
  5281. ExpectApproximatelyEqual(
  5282. @@ -288,7 +289,7 @@ TEST_P(SearchTest, SucceedsWithMetadataIndex) {
  5283. // Perform search.
  5284. SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result,
  5285. - searcher->Search("The weather was excellent."));
  5286. + searcher->Search("The weather was excellent."));
  5287. // Check results.
  5288. ExpectApproximatelyEqual(
  5289. @@ -313,7 +314,7 @@ TEST_P(SearchTest, SucceedsWithMaxResults) {
  5290. // Perform search.
  5291. SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result,
  5292. - searcher->Search("The weather was excellent."));
  5293. + searcher->Search("The weather was excellent."));
  5294. // Check results.
  5295. SearchResult all_results =
  5296. @@ -327,7 +328,8 @@ TEST_P(SearchTest, SucceedsWithMaxResults) {
  5297. }
  5298. INSTANTIATE_TEST_SUITE_P(
  5299. - SearchTest, SearchTest,
  5300. + SearchTest,
  5301. + SearchTest,
  5302. Values(
  5303. SearchParams{
  5304. .name = "Bert",
  5305. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/universal_sentence_encoder_qa_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/universal_sentence_encoder_qa_test.cc
  5306. index 2529060cab275..5f0535b5c1438 100644
  5307. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/universal_sentence_encoder_qa_test.cc
  5308. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/universal_sentence_encoder_qa_test.cc
  5309. @@ -77,8 +77,7 @@ class UniversalSentenceEncoderQATest : public tflite_shims::testing::Test {
  5310. public:
  5311. UniversalSentenceEncoderQATest() {
  5312. // Load model file, and create qa client.
  5313. - const auto filename =
  5314. - JoinPath("./" /*test src dir*/, kTestUseQaModelDir);
  5315. + const auto filename = JoinPath("./" /*test src dir*/, kTestUseQaModelDir);
  5316. RetrievalOptions options;
  5317. options.mutable_base_options()->mutable_model_file()->set_file_name(
  5318. filename);
  5319. @@ -96,7 +95,7 @@ class UniversalSentenceEncoderQATest : public tflite_shims::testing::Test {
  5320. TEST_F(UniversalSentenceEncoderQATest, TestEncodeQuery) {
  5321. ASSERT_TRUE(qa_client_ != nullptr);
  5322. SUPPORT_ASSERT_OK_AND_ASSIGN(const auto encoded_question,
  5323. - qa_client_->EncodeQuery(kQuery));
  5324. + qa_client_->EncodeQuery(kQuery));
  5325. EXPECT_EQ(UniversalSentenceEncoderQA::kFinalEmbeddingSize,
  5326. encoded_question.value_float_size());
  5327. @@ -107,7 +106,7 @@ TEST_F(UniversalSentenceEncoderQATest, TestEncodeQuery) {
  5328. TEST_F(UniversalSentenceEncoderQATest, TestEncodeResponse) {
  5329. ASSERT_TRUE(qa_client_ != nullptr);
  5330. SUPPORT_ASSERT_OK_AND_ASSIGN(const auto encoded_response,
  5331. - qa_client_->EncodeResponse(kResponse, kContext));
  5332. + qa_client_->EncodeResponse(kResponse, kContext));
  5333. EXPECT_EQ(UniversalSentenceEncoderQA::kFinalEmbeddingSize,
  5334. encoded_response.value_float_size());
  5335. @@ -208,13 +207,14 @@ TEST_F(UniversalSentenceEncoderQATest, TestRetrieveWithEncoding) {
  5336. ASSERT_TRUE(qa_client_ != nullptr);
  5337. RetrievalInput input;
  5338. input.set_query_text(kQueryComp);
  5339. - SUPPORT_ASSERT_OK_AND_ASSIGN(const auto& query, qa_client_->EncodeQuery(kQueryComp));
  5340. + SUPPORT_ASSERT_OK_AND_ASSIGN(const auto& query,
  5341. + qa_client_->EncodeQuery(kQueryComp));
  5342. SUPPORT_ASSERT_OK_AND_ASSIGN(const auto& resp0,
  5343. - qa_client_->EncodeResponse(kResponseComp0, ""));
  5344. + qa_client_->EncodeResponse(kResponseComp0, ""));
  5345. SUPPORT_ASSERT_OK_AND_ASSIGN(const auto& resp1,
  5346. - qa_client_->EncodeResponse(kResponseComp1, ""));
  5347. + qa_client_->EncodeResponse(kResponseComp1, ""));
  5348. SUPPORT_ASSERT_OK_AND_ASSIGN(const auto& resp2,
  5349. - qa_client_->EncodeResponse(kResponseComp2, ""));
  5350. + qa_client_->EncodeResponse(kResponseComp2, ""));
  5351. *input.mutable_responses()->Add()->mutable_text_encoding() = resp0;
  5352. *input.mutable_responses()->Add()->mutable_text_encoding() = resp1;
  5353. *input.mutable_responses()->Add()->mutable_text_encoding() = resp2;
  5354. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc
  5355. index 6a0ce66dde9b5..2daf293b48f05 100644
  5356. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc
  5357. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc
  5358. @@ -17,9 +17,9 @@ limitations under the License.
  5359. #include <memory>
  5360. -#include "absl/flags/flag.h" // from @com_google_absl
  5361. +#include "absl/flags/flag.h" // from @com_google_absl
  5362. #include "absl/status/status.h" // from @com_google_absl
  5363. -#include "absl/strings/cord.h" // from @com_google_absl
  5364. +#include "absl/strings/cord.h" // from @com_google_absl
  5365. #include "tensorflow/lite/c/common.h"
  5366. #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
  5367. #include "tensorflow/lite/kernels/builtin_op_kernels.h"
  5368. @@ -70,8 +70,8 @@ constexpr char kMobileNetQuantizedWithMetadata[] =
  5369. constexpr char kAutoMLModelWithMetadata[] = "automl_labeler_model.tflite";
  5370. StatusOr<ImageData> LoadImage(std::string image_name) {
  5371. - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
  5372. - kTestDataDirectory, image_name));
  5373. + return DecodeImageFromFile(
  5374. + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
  5375. }
  5376. // If the proto definition changes, please also change this function.
  5377. @@ -159,9 +159,8 @@ TEST_F(CreateFromOptionsTest, FailsWithTwoModelSources) {
  5378. options.mutable_model_file_with_metadata()->set_file_name(
  5379. JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5380. kMobileNetQuantizedWithMetadata));
  5381. - options.mutable_base_options()->mutable_model_file()->set_file_name(
  5382. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5383. - kMobileNetFloatWithMetadata));
  5384. + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  5385. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
  5386. StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or =
  5387. ImageClassifier::CreateFromOptions(options);
  5388. @@ -234,9 +233,8 @@ TEST_F(CreateFromOptionsTest, FailsWithCombinedWhitelistAndBlacklist) {
  5389. TEST_F(CreateFromOptionsTest, SucceedsWithNumberOfThreads) {
  5390. ImageClassifierOptions options;
  5391. options.set_num_threads(4);
  5392. - options.mutable_model_file_with_metadata()->set_file_name(
  5393. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5394. - kMobileNetFloatWithMetadata));
  5395. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  5396. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
  5397. SUPPORT_ASSERT_OK(ImageClassifier::CreateFromOptions(options));
  5398. }
  5399. @@ -248,9 +246,8 @@ INSTANTIATE_TEST_SUITE_P(Default, NumThreadsTest, testing::Values(0, -2));
  5400. TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) {
  5401. ImageClassifierOptions options;
  5402. options.set_num_threads(GetParam());
  5403. - options.mutable_model_file_with_metadata()->set_file_name(
  5404. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5405. - kMobileNetFloatWithMetadata));
  5406. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  5407. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
  5408. StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or =
  5409. ImageClassifier::CreateFromOptions(options);
  5410. @@ -273,12 +270,12 @@ TEST(ClassifyTest, SucceedsWithFloatModel) {
  5411. ImageClassifierOptions options;
  5412. options.set_max_results(3);
  5413. - options.mutable_model_file_with_metadata()->set_file_name(
  5414. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5415. - kMobileNetFloatWithMetadata));
  5416. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  5417. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
  5418. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
  5419. - ImageClassifier::CreateFromOptions(options));
  5420. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5421. + std::unique_ptr<ImageClassifier> image_classifier,
  5422. + ImageClassifier::CreateFromOptions(options));
  5423. StatusOr<ClassificationResult> result_or =
  5424. image_classifier->Classify(*frame_buffer);
  5425. @@ -307,19 +304,20 @@ TEST(ClassifyTest, SucceedsWithFloatModel) {
  5426. }
  5427. TEST(ClassifyTest, SucceedsWithRegionOfInterest) {
  5428. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("multi_objects.jpg"));
  5429. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
  5430. + LoadImage("multi_objects.jpg"));
  5431. std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
  5432. rgb_image.pixel_data,
  5433. FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
  5434. ImageClassifierOptions options;
  5435. options.set_max_results(1);
  5436. - options.mutable_model_file_with_metadata()->set_file_name(
  5437. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5438. - kMobileNetFloatWithMetadata));
  5439. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  5440. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
  5441. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
  5442. - ImageClassifier::CreateFromOptions(options));
  5443. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5444. + std::unique_ptr<ImageClassifier> image_classifier,
  5445. + ImageClassifier::CreateFromOptions(options));
  5446. // Crop around the soccer ball.
  5447. BoundingBox roi;
  5448. @@ -358,8 +356,9 @@ TEST(ClassifyTest, SucceedsWithQuantizedModel) {
  5449. JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5450. kMobileNetQuantizedWithMetadata));
  5451. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
  5452. - ImageClassifier::CreateFromOptions(options));
  5453. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5454. + std::unique_ptr<ImageClassifier> image_classifier,
  5455. + ImageClassifier::CreateFromOptions(options));
  5456. StatusOr<ClassificationResult> result_or =
  5457. image_classifier->Classify(*frame_buffer);
  5458. @@ -391,12 +390,12 @@ TEST(ClassifyTest, SucceedsWithBaseOptions) {
  5459. ImageClassifierOptions options;
  5460. options.set_max_results(3);
  5461. - options.mutable_base_options()->mutable_model_file()->set_file_name(
  5462. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5463. - kMobileNetFloatWithMetadata));
  5464. + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  5465. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
  5466. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
  5467. - ImageClassifier::CreateFromOptions(options));
  5468. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5469. + std::unique_ptr<ImageClassifier> image_classifier,
  5470. + ImageClassifier::CreateFromOptions(options));
  5471. StatusOr<ClassificationResult> result_or =
  5472. image_classifier->Classify(*frame_buffer);
  5473. @@ -452,8 +451,8 @@ TEST(ClassifyTest, SucceedsWithMiniBenchmark) {
  5474. rgb_image.pixel_data,
  5475. FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
  5476. - auto file_name = JoinPath("./" /*test src dir*/,
  5477. - kTestDataDirectory, kMobileNetFloatWithMetadata);
  5478. + auto file_name = JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5479. + kMobileNetFloatWithMetadata);
  5480. ImageClassifierOptions options;
  5481. options.set_max_results(3);
  5482. @@ -462,8 +461,9 @@ TEST(ClassifyTest, SucceedsWithMiniBenchmark) {
  5483. ConfigureXnnPackMiniBenchmark(/*num_threads=*/2, options);
  5484. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
  5485. - ImageClassifier::CreateFromOptions(options));
  5486. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5487. + std::unique_ptr<ImageClassifier> image_classifier,
  5488. + ImageClassifier::CreateFromOptions(options));
  5489. StatusOr<ClassificationResult> result_or =
  5490. image_classifier->Classify(*frame_buffer);
  5491. @@ -493,11 +493,11 @@ TEST(ClassifyTest, SucceedsWithMiniBenchmark) {
  5492. TEST(ClassifyTest, GetInputCountSucceeds) {
  5493. ImageClassifierOptions options;
  5494. - options.mutable_base_options()->mutable_model_file()->set_file_name(
  5495. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5496. - kMobileNetFloatWithMetadata));
  5497. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
  5498. - ImageClassifier::CreateFromOptions(options));
  5499. + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  5500. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
  5501. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5502. + std::unique_ptr<ImageClassifier> image_classifier,
  5503. + ImageClassifier::CreateFromOptions(options));
  5504. int32_t input_count = image_classifier->GetInputCount();
  5505. EXPECT_THAT(input_count, 1);
  5506. @@ -505,11 +505,11 @@ TEST(ClassifyTest, GetInputCountSucceeds) {
  5507. TEST(ClassifyTest, GetInputShapeSucceeds) {
  5508. ImageClassifierOptions options;
  5509. - options.mutable_base_options()->mutable_model_file()->set_file_name(
  5510. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5511. - kMobileNetFloatWithMetadata));
  5512. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
  5513. - ImageClassifier::CreateFromOptions(options));
  5514. + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  5515. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
  5516. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5517. + std::unique_ptr<ImageClassifier> image_classifier,
  5518. + ImageClassifier::CreateFromOptions(options));
  5519. // Verify the shape array size.
  5520. const TfLiteIntArray* input_shape_0 = image_classifier->GetInputShape(0);
  5521. @@ -523,11 +523,11 @@ TEST(ClassifyTest, GetInputShapeSucceeds) {
  5522. TEST(ClassifyTest, GetOutputCountSucceeds) {
  5523. ImageClassifierOptions options;
  5524. - options.mutable_base_options()->mutable_model_file()->set_file_name(
  5525. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5526. - kMobileNetFloatWithMetadata));
  5527. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
  5528. - ImageClassifier::CreateFromOptions(options));
  5529. + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  5530. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
  5531. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5532. + std::unique_ptr<ImageClassifier> image_classifier,
  5533. + ImageClassifier::CreateFromOptions(options));
  5534. int32_t output_count = image_classifier->GetOutputCount();
  5535. EXPECT_THAT(output_count, 1);
  5536. @@ -535,11 +535,11 @@ TEST(ClassifyTest, GetOutputCountSucceeds) {
  5537. TEST(ClassifyTest, GetOutputShapeSucceeds) {
  5538. ImageClassifierOptions options;
  5539. - options.mutable_base_options()->mutable_model_file()->set_file_name(
  5540. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5541. - kMobileNetFloatWithMetadata));
  5542. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
  5543. - ImageClassifier::CreateFromOptions(options));
  5544. + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  5545. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
  5546. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5547. + std::unique_ptr<ImageClassifier> image_classifier,
  5548. + ImageClassifier::CreateFromOptions(options));
  5549. // Verify the shape array size.
  5550. const TfLiteIntArray* output_shape_0 = image_classifier->GetOutputShape(0);
  5551. @@ -604,9 +604,8 @@ class PostprocessTest : public tflite_shims::testing::Test {
  5552. TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) {
  5553. ImageClassifierOptions options;
  5554. - options.mutable_model_file_with_metadata()->set_file_name(
  5555. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5556. - kAutoMLModelWithMetadata));
  5557. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  5558. + "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata));
  5559. options.set_max_results(3);
  5560. SetUp(options);
  5561. @@ -618,9 +617,10 @@ TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) {
  5562. std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255,
  5563. /*sunflowers*/ 32, /*tulips*/ 128};
  5564. SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor));
  5565. - SUPPORT_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
  5566. - test_image_classifier_->Postprocess(
  5567. - {output_tensor}, *dummy_frame_buffer_, /*roi=*/{}));
  5568. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5569. + ClassificationResult result,
  5570. + test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_,
  5571. + /*roi=*/{}));
  5572. ExpectApproximatelyEqual(
  5573. result,
  5574. ParseTextProtoOrDie<ClassificationResult>(
  5575. @@ -635,9 +635,8 @@ TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) {
  5576. TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) {
  5577. ImageClassifierOptions options;
  5578. - options.mutable_model_file_with_metadata()->set_file_name(
  5579. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5580. - kAutoMLModelWithMetadata));
  5581. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  5582. + "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata));
  5583. options.set_score_threshold(0.4);
  5584. SetUp(options);
  5585. @@ -649,9 +648,10 @@ TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) {
  5586. std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255,
  5587. /*sunflowers*/ 32, /*tulips*/ 128};
  5588. SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor));
  5589. - SUPPORT_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
  5590. - test_image_classifier_->Postprocess(
  5591. - {output_tensor}, *dummy_frame_buffer_, /*roi=*/{}));
  5592. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5593. + ClassificationResult result,
  5594. + test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_,
  5595. + /*roi=*/{}));
  5596. ExpectApproximatelyEqual(
  5597. result,
  5598. @@ -666,9 +666,8 @@ TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) {
  5599. TEST_F(PostprocessTest, SucceedsWithWhitelistOption) {
  5600. ImageClassifierOptions options;
  5601. - options.mutable_model_file_with_metadata()->set_file_name(
  5602. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5603. - kAutoMLModelWithMetadata));
  5604. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  5605. + "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata));
  5606. options.add_class_name_whitelist("dandelion");
  5607. options.add_class_name_whitelist("daisy");
  5608. @@ -681,9 +680,10 @@ TEST_F(PostprocessTest, SucceedsWithWhitelistOption) {
  5609. std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255,
  5610. /*sunflowers*/ 32, /*tulips*/ 128};
  5611. SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor));
  5612. - SUPPORT_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
  5613. - test_image_classifier_->Postprocess(
  5614. - {output_tensor}, *dummy_frame_buffer_, /*roi=*/{}));
  5615. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5616. + ClassificationResult result,
  5617. + test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_,
  5618. + /*roi=*/{}));
  5619. ExpectApproximatelyEqual(
  5620. result,
  5621. ParseTextProtoOrDie<ClassificationResult>(
  5622. @@ -697,9 +697,8 @@ TEST_F(PostprocessTest, SucceedsWithWhitelistOption) {
  5623. TEST_F(PostprocessTest, SucceedsWithBlacklistOption) {
  5624. ImageClassifierOptions options;
  5625. - options.mutable_model_file_with_metadata()->set_file_name(
  5626. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5627. - kAutoMLModelWithMetadata));
  5628. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  5629. + "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata));
  5630. options.add_class_name_blacklist("dandelion");
  5631. options.add_class_name_blacklist("daisy");
  5632. @@ -712,9 +711,10 @@ TEST_F(PostprocessTest, SucceedsWithBlacklistOption) {
  5633. std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255,
  5634. /*sunflowers*/ 32, /*tulips*/ 128};
  5635. SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor));
  5636. - SUPPORT_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
  5637. - test_image_classifier_->Postprocess(
  5638. - {output_tensor}, *dummy_frame_buffer_, /*roi=*/{}));
  5639. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5640. + ClassificationResult result,
  5641. + test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_,
  5642. + /*roi=*/{}));
  5643. ExpectApproximatelyEqual(
  5644. result,
  5645. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc
  5646. index 6ce017d3f1728..41226f602a26b 100644
  5647. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc
  5648. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc
  5649. @@ -17,7 +17,7 @@ limitations under the License.
  5650. #include <memory>
  5651. -#include "absl/flags/flag.h" // from @com_google_absl
  5652. +#include "absl/flags/flag.h" // from @com_google_absl
  5653. #include "absl/status/status.h" // from @com_google_absl
  5654. #include "tensorflow/lite/c/common.h"
  5655. #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
  5656. @@ -59,8 +59,8 @@ constexpr char kMobileNetV3[] = "mobilenet_v3_small_100_224_embedder.tflite";
  5657. constexpr double kSimilarityTolerancy = 1e-6;
  5658. StatusOr<ImageData> LoadImage(std::string image_name) {
  5659. - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
  5660. - kTestDataDirectory, image_name));
  5661. + return DecodeImageFromFile(
  5662. + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
  5663. }
  5664. class MobileNetV3OpResolver : public ::tflite::MutableOpResolver {
  5665. @@ -93,8 +93,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {};
  5666. TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
  5667. ImageEmbedderOptions options;
  5668. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  5669. - "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
  5670. + options.mutable_model_file_with_metadata()->set_file_name(
  5671. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
  5672. SUPPORT_ASSERT_OK(ImageEmbedder::CreateFromOptions(
  5673. options, absl::make_unique<MobileNetV3OpResolver>()));
  5674. @@ -113,8 +113,8 @@ class MobileNetV3OpResolverMissingOps : public ::tflite::MutableOpResolver {
  5675. TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
  5676. ImageEmbedderOptions options;
  5677. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  5678. - "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
  5679. + options.mutable_model_file_with_metadata()->set_file_name(
  5680. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
  5681. auto image_embedder_or = ImageEmbedder::CreateFromOptions(
  5682. options, absl::make_unique<MobileNetV3OpResolverMissingOps>());
  5683. @@ -231,8 +231,9 @@ TEST(CosineSimilarityTest, Succeeds) {
  5684. // Prevent literal from being interpreted as null-terminated C-style string.
  5685. *v_quantized.mutable_value_string() = std::string("\x80\x00\x00\x00", 4);
  5686. - SUPPORT_ASSERT_OK_AND_ASSIGN(double float_similarity,
  5687. - ImageEmbedder::CosineSimilarity(u_float, v_float));
  5688. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5689. + double float_similarity,
  5690. + ImageEmbedder::CosineSimilarity(u_float, v_float));
  5691. SUPPORT_ASSERT_OK_AND_ASSIGN(
  5692. double quantized_similarity,
  5693. ImageEmbedder::CosineSimilarity(u_quantized, v_quantized));
  5694. @@ -246,10 +247,10 @@ TEST(CosineSimilarityTest, Succeeds) {
  5695. TEST(EmbedTest, SucceedsWithoutL2Normalization) {
  5696. // Create embedder.
  5697. ImageEmbedderOptions options;
  5698. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  5699. - "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
  5700. + options.mutable_model_file_with_metadata()->set_file_name(
  5701. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
  5702. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
  5703. - ImageEmbedder::CreateFromOptions(options));
  5704. + ImageEmbedder::CreateFromOptions(options));
  5705. // Load images: one is a crop of the other.
  5706. SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
  5707. std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer(
  5708. @@ -260,10 +261,10 @@ TEST(EmbedTest, SucceedsWithoutL2Normalization) {
  5709. // Extract both embeddings.
  5710. SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
  5711. - embedder->Embed(*image_frame_buffer));
  5712. + embedder->Embed(*image_frame_buffer));
  5713. ImageDataFree(&image);
  5714. SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
  5715. - embedder->Embed(*crop_frame_buffer));
  5716. + embedder->Embed(*crop_frame_buffer));
  5717. ImageDataFree(&crop);
  5718. // Check results sizes
  5719. @@ -276,9 +277,9 @@ TEST(EmbedTest, SucceedsWithoutL2Normalization) {
  5720. crop_result.embeddings(0).feature_vector();
  5721. EXPECT_EQ(crop_feature_vector.value_float_size(), 1024);
  5722. // Check cosine similarity.
  5723. - SUPPORT_ASSERT_OK_AND_ASSIGN(double similarity,
  5724. - ImageEmbedder::CosineSimilarity(image_feature_vector,
  5725. - crop_feature_vector));
  5726. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5727. + double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector,
  5728. + crop_feature_vector));
  5729. double expected_similarity = 0.932738;
  5730. EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
  5731. }
  5732. @@ -287,11 +288,11 @@ TEST(EmbedTest, SucceedsWithoutL2Normalization) {
  5733. TEST(EmbedTest, SucceedsWithL2Normalization) {
  5734. // Create embedder.
  5735. ImageEmbedderOptions options;
  5736. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  5737. - "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
  5738. + options.mutable_model_file_with_metadata()->set_file_name(
  5739. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
  5740. options.set_l2_normalize(true);
  5741. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
  5742. - ImageEmbedder::CreateFromOptions(options));
  5743. + ImageEmbedder::CreateFromOptions(options));
  5744. // Load images: one is a crop of the other.
  5745. SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
  5746. std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer(
  5747. @@ -302,10 +303,10 @@ TEST(EmbedTest, SucceedsWithL2Normalization) {
  5748. // Extract both embeddings.
  5749. SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
  5750. - embedder->Embed(*image_frame_buffer));
  5751. + embedder->Embed(*image_frame_buffer));
  5752. ImageDataFree(&image);
  5753. SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
  5754. - embedder->Embed(*crop_frame_buffer));
  5755. + embedder->Embed(*crop_frame_buffer));
  5756. ImageDataFree(&crop);
  5757. // Check results sizes
  5758. @@ -318,9 +319,9 @@ TEST(EmbedTest, SucceedsWithL2Normalization) {
  5759. crop_result.embeddings(0).feature_vector();
  5760. EXPECT_EQ(crop_feature_vector.value_float_size(), 1024);
  5761. // Check cosine similarity.
  5762. - SUPPORT_ASSERT_OK_AND_ASSIGN(double similarity,
  5763. - ImageEmbedder::CosineSimilarity(image_feature_vector,
  5764. - crop_feature_vector));
  5765. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5766. + double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector,
  5767. + crop_feature_vector));
  5768. double expected_similarity = 0.932738;
  5769. EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
  5770. }
  5771. @@ -331,12 +332,12 @@ TEST(EmbedTest, SucceedsWithL2Normalization) {
  5772. TEST(EmbedTest, SucceedsWithQuantization) {
  5773. // Create embedder.
  5774. ImageEmbedderOptions options;
  5775. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  5776. - "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
  5777. + options.mutable_model_file_with_metadata()->set_file_name(
  5778. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
  5779. options.set_l2_normalize(true);
  5780. options.set_quantize(true);
  5781. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
  5782. - ImageEmbedder::CreateFromOptions(options));
  5783. + ImageEmbedder::CreateFromOptions(options));
  5784. // Load images: one is a crop of the other.
  5785. SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
  5786. std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer(
  5787. @@ -347,10 +348,10 @@ TEST(EmbedTest, SucceedsWithQuantization) {
  5788. // Extract both embeddings.
  5789. SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
  5790. - embedder->Embed(*image_frame_buffer));
  5791. + embedder->Embed(*image_frame_buffer));
  5792. ImageDataFree(&image);
  5793. SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
  5794. - embedder->Embed(*crop_frame_buffer));
  5795. + embedder->Embed(*crop_frame_buffer));
  5796. ImageDataFree(&crop);
  5797. // Check results sizes
  5798. @@ -363,9 +364,9 @@ TEST(EmbedTest, SucceedsWithQuantization) {
  5799. crop_result.embeddings(0).feature_vector();
  5800. EXPECT_EQ(crop_feature_vector.value_string().size(), 1024);
  5801. // Check cosine similarity.
  5802. - SUPPORT_ASSERT_OK_AND_ASSIGN(double similarity,
  5803. - ImageEmbedder::CosineSimilarity(image_feature_vector,
  5804. - crop_feature_vector));
  5805. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5806. + double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector,
  5807. + crop_feature_vector));
  5808. // Close to but expectedly different from the above tests due to slight loss
  5809. // of precision during quantization:
  5810. double expected_similarity = 0.929717;
  5811. @@ -378,10 +379,10 @@ TEST(EmbedTest, SucceedsWithQuantization) {
  5812. TEST(EmbedTest, SucceedsWithRegionOfInterest) {
  5813. // Create embedder.
  5814. ImageEmbedderOptions options;
  5815. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  5816. - "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
  5817. + options.mutable_model_file_with_metadata()->set_file_name(
  5818. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
  5819. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
  5820. - ImageEmbedder::CreateFromOptions(options));
  5821. + ImageEmbedder::CreateFromOptions(options));
  5822. // Load images: one is a crop of the other.
  5823. SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
  5824. std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer(
  5825. @@ -398,10 +399,10 @@ TEST(EmbedTest, SucceedsWithRegionOfInterest) {
  5826. // Extract both embeddings.
  5827. SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
  5828. - embedder->Embed(*image_frame_buffer, roi));
  5829. + embedder->Embed(*image_frame_buffer, roi));
  5830. ImageDataFree(&image);
  5831. SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
  5832. - embedder->Embed(*crop_frame_buffer));
  5833. + embedder->Embed(*crop_frame_buffer));
  5834. ImageDataFree(&crop);
  5835. // Check results sizes
  5836. @@ -414,9 +415,9 @@ TEST(EmbedTest, SucceedsWithRegionOfInterest) {
  5837. crop_result.embeddings(0).feature_vector();
  5838. EXPECT_EQ(crop_feature_vector.value_float_size(), 1024);
  5839. // Check cosine similarity.
  5840. - SUPPORT_ASSERT_OK_AND_ASSIGN(double similarity,
  5841. - ImageEmbedder::CosineSimilarity(image_feature_vector,
  5842. - crop_feature_vector));
  5843. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  5844. + double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector,
  5845. + crop_feature_vector));
  5846. double expected_similarity = 0.999914;
  5847. EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
  5848. }
  5849. @@ -424,10 +425,10 @@ TEST(EmbedTest, SucceedsWithRegionOfInterest) {
  5850. TEST(GetEmbeddingDimension, Succeeds) {
  5851. // Create embedder.
  5852. ImageEmbedderOptions options;
  5853. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  5854. - "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
  5855. + options.mutable_model_file_with_metadata()->set_file_name(
  5856. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
  5857. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
  5858. - ImageEmbedder::CreateFromOptions(options));
  5859. + ImageEmbedder::CreateFromOptions(options));
  5860. EXPECT_EQ(embedder->GetEmbeddingDimension(0), 1024);
  5861. EXPECT_EQ(embedder->GetEmbeddingDimension(1), -1);
  5862. @@ -436,10 +437,10 @@ TEST(GetEmbeddingDimension, Succeeds) {
  5863. TEST(GetNumberOfOutputLayers, Succeeds) {
  5864. // Create embedder.
  5865. ImageEmbedderOptions options;
  5866. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  5867. - "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
  5868. + options.mutable_model_file_with_metadata()->set_file_name(
  5869. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
  5870. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
  5871. - ImageEmbedder::CreateFromOptions(options));
  5872. + ImageEmbedder::CreateFromOptions(options));
  5873. EXPECT_EQ(embedder->GetNumberOfOutputLayers(), 1);
  5874. }
  5875. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_searcher_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_searcher_test.cc
  5876. index 0b1f3b11b383c..00183eb65b5df 100644
  5877. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_searcher_test.cc
  5878. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_searcher_test.cc
  5879. @@ -18,9 +18,9 @@ limitations under the License.
  5880. #include <memory>
  5881. #include <string>
  5882. -#include "absl/flags/flag.h" // from @com_google_absl
  5883. -#include "absl/status/status.h" // from @com_google_absl
  5884. -#include "absl/strings/cord.h" // from @com_google_absl
  5885. +#include "absl/flags/flag.h" // from @com_google_absl
  5886. +#include "absl/status/status.h" // from @com_google_absl
  5887. +#include "absl/strings/cord.h" // from @com_google_absl
  5888. #include "absl/strings/str_cat.h" // from @com_google_absl
  5889. #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
  5890. #include "tensorflow_lite_support/cc/common.h"
  5891. @@ -66,8 +66,8 @@ constexpr char kMobileNetV3Searcher[] =
  5892. "mobilenet_v3_small_100_224_searcher.tflite";
  5893. StatusOr<ImageData> LoadImage(std::string image_name) {
  5894. - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
  5895. - kTestDataDirectory, image_name));
  5896. + return DecodeImageFromFile(
  5897. + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
  5898. }
  5899. // Checks that the two provided `SearchResult` protos are equal, with a
  5900. @@ -88,9 +88,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {};
  5901. TEST_F(CreateFromOptionsTest, SucceedsWithStandaloneIndex) {
  5902. ImageSearcherOptions options;
  5903. - options.mutable_base_options()->mutable_model_file()->set_file_name(
  5904. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5905. - kMobileNetV3Embedder));
  5906. + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  5907. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder));
  5908. options.mutable_embedding_options()->set_l2_normalize(true);
  5909. options.mutable_search_options()->mutable_index_file()->set_file_name(
  5910. JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex));
  5911. @@ -100,9 +99,8 @@ TEST_F(CreateFromOptionsTest, SucceedsWithStandaloneIndex) {
  5912. TEST_F(CreateFromOptionsTest, SucceedsWithMetadataIndex) {
  5913. ImageSearcherOptions options;
  5914. - options.mutable_base_options()->mutable_model_file()->set_file_name(
  5915. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5916. - kMobileNetV3Searcher));
  5917. + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  5918. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Searcher));
  5919. options.mutable_embedding_options()->set_l2_normalize(true);
  5920. SUPPORT_ASSERT_OK(ImageSearcher::CreateFromOptions(options));
  5921. @@ -129,9 +127,8 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
  5922. TEST_F(CreateFromOptionsTest, FailsWithMissingIndex) {
  5923. ImageSearcherOptions options;
  5924. - options.mutable_base_options()->mutable_model_file()->set_file_name(
  5925. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5926. - kMobileNetV3Embedder));
  5927. + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  5928. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder));
  5929. options.mutable_embedding_options()->set_l2_normalize(true);
  5930. StatusOr<std::unique_ptr<ImageSearcher>> image_searcher_or =
  5931. @@ -151,9 +148,8 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingIndex) {
  5932. TEST_F(CreateFromOptionsTest, FailsWithQuantization) {
  5933. ImageSearcherOptions options;
  5934. - options.mutable_base_options()->mutable_model_file()->set_file_name(
  5935. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5936. - kMobileNetV3Embedder));
  5937. + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  5938. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder));
  5939. options.mutable_embedding_options()->set_l2_normalize(true);
  5940. options.mutable_embedding_options()->set_quantize(true);
  5941. options.mutable_search_options()->mutable_index_file()->set_file_name(
  5942. @@ -174,9 +170,8 @@ TEST_F(CreateFromOptionsTest, FailsWithQuantization) {
  5943. TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
  5944. ImageSearcherOptions options;
  5945. - options.mutable_base_options()->mutable_model_file()->set_file_name(
  5946. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5947. - kMobileNetV3Embedder));
  5948. + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  5949. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder));
  5950. options.mutable_embedding_options()->set_l2_normalize(true);
  5951. options.mutable_search_options()->mutable_index_file()->set_file_name(
  5952. JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex));
  5953. @@ -197,14 +192,13 @@ TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
  5954. TEST(SearchTest, SucceedsWithStandaloneIndex) {
  5955. // Create Searcher.
  5956. ImageSearcherOptions options;
  5957. - options.mutable_base_options()->mutable_model_file()->set_file_name(
  5958. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5959. - kMobileNetV3Embedder));
  5960. + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  5961. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder));
  5962. options.mutable_embedding_options()->set_l2_normalize(true);
  5963. options.mutable_search_options()->mutable_index_file()->set_file_name(
  5964. JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex));
  5965. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSearcher> searcher,
  5966. - ImageSearcher::CreateFromOptions(options));
  5967. + ImageSearcher::CreateFromOptions(options));
  5968. // Load image.
  5969. SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
  5970. std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
  5971. @@ -212,7 +206,7 @@ TEST(SearchTest, SucceedsWithStandaloneIndex) {
  5972. // Perform search.
  5973. SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result,
  5974. - searcher->Search(*frame_buffer));
  5975. + searcher->Search(*frame_buffer));
  5976. ImageDataFree(&image);
  5977. // Check results.
  5978. @@ -229,12 +223,11 @@ TEST(SearchTest, SucceedsWithStandaloneIndex) {
  5979. TEST(SearchTest, SucceedsWithMetadataIndex) {
  5980. // Create Searcher.
  5981. ImageSearcherOptions options;
  5982. - options.mutable_base_options()->mutable_model_file()->set_file_name(
  5983. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  5984. - kMobileNetV3Searcher));
  5985. + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  5986. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Searcher));
  5987. options.mutable_embedding_options()->set_l2_normalize(true);
  5988. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSearcher> searcher,
  5989. - ImageSearcher::CreateFromOptions(options));
  5990. + ImageSearcher::CreateFromOptions(options));
  5991. // Load image.
  5992. SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
  5993. std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
  5994. @@ -242,7 +235,7 @@ TEST(SearchTest, SucceedsWithMetadataIndex) {
  5995. // Perform search.
  5996. SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result,
  5997. - searcher->Search(*frame_buffer));
  5998. + searcher->Search(*frame_buffer));
  5999. ImageDataFree(&image);
  6000. // Check results.
  6001. @@ -259,15 +252,14 @@ TEST(SearchTest, SucceedsWithMetadataIndex) {
  6002. TEST(SearchTest, SucceedsWithMaxResults) {
  6003. // Create Searcher.
  6004. ImageSearcherOptions options;
  6005. - options.mutable_base_options()->mutable_model_file()->set_file_name(
  6006. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  6007. - kMobileNetV3Embedder));
  6008. + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  6009. + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder));
  6010. options.mutable_embedding_options()->set_l2_normalize(true);
  6011. options.mutable_search_options()->mutable_index_file()->set_file_name(
  6012. JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex));
  6013. options.mutable_search_options()->set_max_results(2);
  6014. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSearcher> searcher,
  6015. - ImageSearcher::CreateFromOptions(options));
  6016. + ImageSearcher::CreateFromOptions(options));
  6017. // Load image.
  6018. SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
  6019. std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
  6020. @@ -275,7 +267,7 @@ TEST(SearchTest, SucceedsWithMaxResults) {
  6021. // Perform search.
  6022. SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result,
  6023. - searcher->Search(*frame_buffer));
  6024. + searcher->Search(*frame_buffer));
  6025. ImageDataFree(&image);
  6026. // Check results.
  6027. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc
  6028. index e32b8e4c27524..8671b68c3b884 100644
  6029. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc
  6030. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc
  6031. @@ -17,9 +17,9 @@ limitations under the License.
  6032. #include <memory>
  6033. -#include "absl/flags/flag.h" // from @com_google_absl
  6034. +#include "absl/flags/flag.h" // from @com_google_absl
  6035. #include "absl/status/status.h" // from @com_google_absl
  6036. -#include "absl/strings/cord.h" // from @com_google_absl
  6037. +#include "absl/strings/cord.h" // from @com_google_absl
  6038. #include "tensorflow/lite/c/common.h"
  6039. #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
  6040. #include "tensorflow/lite/kernels/builtin_op_kernels.h"
  6041. @@ -99,8 +99,8 @@ constexpr float kGoldenMaskTolerance = 1e-2;
  6042. constexpr int kGoldenMaskMagnificationFactor = 10;
  6043. StatusOr<ImageData> LoadImage(std::string image_name) {
  6044. - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
  6045. - kTestDataDirectory, image_name));
  6046. + return DecodeImageFromFile(
  6047. + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
  6048. }
  6049. // Checks that the two provided `Segmentation` protos are equal.
  6050. @@ -141,8 +141,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {};
  6051. TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
  6052. ImageSegmenterOptions options;
  6053. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6054. - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6055. + options.mutable_model_file_with_metadata()->set_file_name(
  6056. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6057. SUPPORT_ASSERT_OK(ImageSegmenter::CreateFromOptions(
  6058. options, absl::make_unique<DeepLabOpResolver>()));
  6059. @@ -160,8 +160,8 @@ class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
  6060. TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
  6061. ImageSegmenterOptions options;
  6062. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6063. - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6064. + options.mutable_model_file_with_metadata()->set_file_name(
  6065. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6066. auto image_segmenter_or = ImageSegmenter::CreateFromOptions(
  6067. options, absl::make_unique<DeepLabOpResolverMissingOps>());
  6068. @@ -177,10 +177,10 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
  6069. TEST_F(CreateFromOptionsTest, FailsWithTwoModelSources) {
  6070. ImageSegmenterOptions options;
  6071. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6072. - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6073. - options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  6074. - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6075. + options.mutable_model_file_with_metadata()->set_file_name(
  6076. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6077. + options.mutable_base_options()->mutable_model_file()->set_file_name(
  6078. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6079. StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or =
  6080. ImageSegmenter::CreateFromOptions(options);
  6081. @@ -212,8 +212,8 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
  6082. TEST_F(CreateFromOptionsTest, FailsWithUnspecifiedOutputType) {
  6083. ImageSegmenterOptions options;
  6084. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6085. - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6086. + options.mutable_model_file_with_metadata()->set_file_name(
  6087. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6088. options.set_output_type(ImageSegmenterOptions::UNSPECIFIED);
  6089. auto image_segmenter_or = ImageSegmenter::CreateFromOptions(options);
  6090. @@ -230,8 +230,8 @@ TEST_F(CreateFromOptionsTest, FailsWithUnspecifiedOutputType) {
  6091. TEST_F(CreateFromOptionsTest, SucceedsWithNumberOfThreads) {
  6092. ImageSegmenterOptions options;
  6093. options.set_num_threads(4);
  6094. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6095. - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6096. + options.mutable_model_file_with_metadata()->set_file_name(
  6097. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6098. SUPPORT_ASSERT_OK(ImageSegmenter::CreateFromOptions(options));
  6099. }
  6100. @@ -243,8 +243,8 @@ INSTANTIATE_TEST_SUITE_P(Default, NumThreadsTest, testing::Values(0, -2));
  6101. TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) {
  6102. ImageSegmenterOptions options;
  6103. options.set_num_threads(GetParam());
  6104. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6105. - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6106. + options.mutable_model_file_with_metadata()->set_file_name(
  6107. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6108. StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or =
  6109. ImageSegmenter::CreateFromOptions(options);
  6110. @@ -263,21 +263,21 @@ TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) {
  6111. TEST(SegmentTest, SucceedsWithCategoryMask) {
  6112. // Load input and build frame buffer.
  6113. SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
  6114. - LoadImage("segmentation_input_rotation0.jpg"));
  6115. + LoadImage("segmentation_input_rotation0.jpg"));
  6116. std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
  6117. rgb_image.pixel_data,
  6118. FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
  6119. // Load golden mask output.
  6120. SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask,
  6121. - LoadImage("segmentation_golden_rotation0.png"));
  6122. + LoadImage("segmentation_golden_rotation0.png"));
  6123. ImageSegmenterOptions options;
  6124. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6125. - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6126. + options.mutable_model_file_with_metadata()->set_file_name(
  6127. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6128. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> image_segmenter,
  6129. - ImageSegmenter::CreateFromOptions(options));
  6130. + ImageSegmenter::CreateFromOptions(options));
  6131. SUPPORT_ASSERT_OK_AND_ASSIGN(const SegmentationResult result,
  6132. - image_segmenter->Segment(*frame_buffer));
  6133. + image_segmenter->Segment(*frame_buffer));
  6134. EXPECT_EQ(result.segmentation_size(), 1);
  6135. const Segmentation& segmentation = result.segmentation(0);
  6136. @@ -301,23 +301,24 @@ TEST(SegmentTest, SucceedsWithCategoryMask) {
  6137. TEST(SegmentTest, SucceedsWithOrientation) {
  6138. // Load input and build frame buffer with kRightBottom orientation.
  6139. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
  6140. - LoadImage("segmentation_input_rotation90_flop.jpg"));
  6141. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  6142. + ImageData rgb_image, LoadImage("segmentation_input_rotation90_flop.jpg"));
  6143. std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
  6144. rgb_image.pixel_data,
  6145. FrameBuffer::Dimension{rgb_image.width, rgb_image.height},
  6146. FrameBuffer::Orientation::kRightBottom);
  6147. // Load golden mask output.
  6148. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask,
  6149. - LoadImage("segmentation_golden_rotation90_flop.png"));
  6150. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  6151. + ImageData golden_mask,
  6152. + LoadImage("segmentation_golden_rotation90_flop.png"));
  6153. ImageSegmenterOptions options;
  6154. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6155. - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6156. + options.mutable_model_file_with_metadata()->set_file_name(
  6157. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6158. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> image_segmenter,
  6159. - ImageSegmenter::CreateFromOptions(options));
  6160. + ImageSegmenter::CreateFromOptions(options));
  6161. SUPPORT_ASSERT_OK_AND_ASSIGN(const SegmentationResult result,
  6162. - image_segmenter->Segment(*frame_buffer));
  6163. + image_segmenter->Segment(*frame_buffer));
  6164. EXPECT_EQ(result.segmentation_size(), 1);
  6165. const Segmentation& segmentation = result.segmentation(0);
  6166. @@ -341,21 +342,21 @@ TEST(SegmentTest, SucceedsWithOrientation) {
  6167. TEST(SegmentTest, SucceedsWithBaseOptions) {
  6168. // Load input and build frame buffer.
  6169. SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
  6170. - LoadImage("segmentation_input_rotation0.jpg"));
  6171. + LoadImage("segmentation_input_rotation0.jpg"));
  6172. std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
  6173. rgb_image.pixel_data,
  6174. FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
  6175. // Load golden mask output.
  6176. SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask,
  6177. - LoadImage("segmentation_golden_rotation0.png"));
  6178. + LoadImage("segmentation_golden_rotation0.png"));
  6179. ImageSegmenterOptions options;
  6180. - options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  6181. - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6182. + options.mutable_base_options()->mutable_model_file()->set_file_name(
  6183. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6184. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> image_segmenter,
  6185. - ImageSegmenter::CreateFromOptions(options));
  6186. + ImageSegmenter::CreateFromOptions(options));
  6187. SUPPORT_ASSERT_OK_AND_ASSIGN(const SegmentationResult result,
  6188. - image_segmenter->Segment(*frame_buffer));
  6189. + image_segmenter->Segment(*frame_buffer));
  6190. EXPECT_EQ(result.segmentation_size(), 1);
  6191. const Segmentation& segmentation = result.segmentation(0);
  6192. @@ -461,18 +462,18 @@ class PostprocessTest : public tflite_shims::testing::Test {
  6193. TEST_F(PostprocessTest, SucceedsWithCategoryMask) {
  6194. ImageSegmenterOptions options;
  6195. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6196. - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6197. + options.mutable_model_file_with_metadata()->set_file_name(
  6198. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6199. std::unique_ptr<FrameBuffer> frame_buffer =
  6200. CreateFromRgbaRawBuffer(/*input=*/nullptr, {});
  6201. SetUp(options);
  6202. ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_;
  6203. SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor,
  6204. - FillAndGetOutputTensor());
  6205. + FillAndGetOutputTensor());
  6206. SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result,
  6207. - test_image_segmenter_->Postprocess(
  6208. - {output_tensor}, *frame_buffer, /*roi=*/{}));
  6209. + test_image_segmenter_->Postprocess(
  6210. + {output_tensor}, *frame_buffer, /*roi=*/{}));
  6211. EXPECT_EQ(result.segmentation_size(), 1);
  6212. const Segmentation& segmentation = result.segmentation(0);
  6213. @@ -487,8 +488,8 @@ TEST_F(PostprocessTest, SucceedsWithCategoryMask) {
  6214. TEST_F(PostprocessTest, SucceedsWithCategoryMaskAndOrientation) {
  6215. ImageSegmenterOptions options;
  6216. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6217. - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6218. + options.mutable_model_file_with_metadata()->set_file_name(
  6219. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6220. // Frame buffer with kRightBottom orientation.
  6221. std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbaRawBuffer(
  6222. /*input=*/nullptr, {}, FrameBuffer::Orientation::kRightBottom);
  6223. @@ -496,10 +497,10 @@ TEST_F(PostprocessTest, SucceedsWithCategoryMaskAndOrientation) {
  6224. SetUp(options);
  6225. ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_;
  6226. SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor,
  6227. - FillAndGetOutputTensor());
  6228. + FillAndGetOutputTensor());
  6229. SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result,
  6230. - test_image_segmenter_->Postprocess(
  6231. - {output_tensor}, *frame_buffer, /*roi=*/{}));
  6232. + test_image_segmenter_->Postprocess(
  6233. + {output_tensor}, *frame_buffer, /*roi=*/{}));
  6234. EXPECT_EQ(result.segmentation_size(), 1);
  6235. const Segmentation& segmentation = result.segmentation(0);
  6236. @@ -515,18 +516,18 @@ TEST_F(PostprocessTest, SucceedsWithCategoryMaskAndOrientation) {
  6237. TEST_F(PostprocessTest, SucceedsWithConfidenceMask) {
  6238. ImageSegmenterOptions options;
  6239. options.set_output_type(ImageSegmenterOptions::CONFIDENCE_MASK);
  6240. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6241. - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6242. + options.mutable_model_file_with_metadata()->set_file_name(
  6243. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6244. std::unique_ptr<FrameBuffer> frame_buffer =
  6245. CreateFromRgbaRawBuffer(/*input=*/nullptr, {});
  6246. SetUp(options);
  6247. ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_;
  6248. SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor,
  6249. - FillAndGetOutputTensor());
  6250. + FillAndGetOutputTensor());
  6251. SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result,
  6252. - test_image_segmenter_->Postprocess(
  6253. - {output_tensor}, *frame_buffer, /*roi=*/{}));
  6254. + test_image_segmenter_->Postprocess(
  6255. + {output_tensor}, *frame_buffer, /*roi=*/{}));
  6256. EXPECT_EQ(result.segmentation_size(), 1);
  6257. const Segmentation& segmentation = result.segmentation(0);
  6258. @@ -547,8 +548,8 @@ TEST_F(PostprocessTest, SucceedsWithConfidenceMask) {
  6259. TEST_F(PostprocessTest, SucceedsWithConfidenceMaskAndOrientation) {
  6260. ImageSegmenterOptions options;
  6261. options.set_output_type(ImageSegmenterOptions::CONFIDENCE_MASK);
  6262. - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6263. - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6264. + options.mutable_model_file_with_metadata()->set_file_name(
  6265. + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
  6266. // Frame buffer with kRightBottom orientation.
  6267. std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbaRawBuffer(
  6268. /*input=*/nullptr, {}, FrameBuffer::Orientation::kRightBottom);
  6269. @@ -556,10 +557,10 @@ TEST_F(PostprocessTest, SucceedsWithConfidenceMaskAndOrientation) {
  6270. SetUp(options);
  6271. ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_;
  6272. SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor,
  6273. - FillAndGetOutputTensor());
  6274. + FillAndGetOutputTensor());
  6275. SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result,
  6276. - test_image_segmenter_->Postprocess(
  6277. - {output_tensor}, *frame_buffer, /*roi=*/{}));
  6278. + test_image_segmenter_->Postprocess(
  6279. + {output_tensor}, *frame_buffer, /*roi=*/{}));
  6280. EXPECT_EQ(result.segmentation_size(), 1);
  6281. const Segmentation& segmentation = result.segmentation(0);
  6282. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc
  6283. index a4f35574d7bfe..6c0f395868e20 100644
  6284. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc
  6285. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc
  6286. @@ -17,9 +17,9 @@ limitations under the License.
  6287. #include <memory>
  6288. -#include "absl/flags/flag.h" // from @com_google_absl
  6289. +#include "absl/flags/flag.h" // from @com_google_absl
  6290. #include "absl/status/status.h" // from @com_google_absl
  6291. -#include "absl/strings/cord.h" // from @com_google_absl
  6292. +#include "absl/strings/cord.h" // from @com_google_absl
  6293. #include "tensorflow/lite/c/common.h"
  6294. #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
  6295. #include "tensorflow/lite/kernels/builtin_op_kernels.h"
  6296. @@ -103,8 +103,8 @@ constexpr char kEfficientDetWithMetadata[] =
  6297. "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite";
  6298. StatusOr<ImageData> LoadImage(std::string image_name) {
  6299. - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
  6300. - kTestDataDirectory, image_name));
  6301. + return DecodeImageFromFile(
  6302. + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
  6303. }
  6304. // Checks that the two provided `DetectionResult` protos are equal, with a
  6305. @@ -153,9 +153,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {};
  6306. TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
  6307. ObjectDetectorOptions options;
  6308. - options.mutable_model_file_with_metadata()->set_file_name(
  6309. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  6310. - kMobileSsdWithMetadata));
  6311. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6312. + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
  6313. SUPPORT_ASSERT_OK(ObjectDetector::CreateFromOptions(
  6314. options, absl::make_unique<MobileSsdQuantizedOpResolver>()));
  6315. @@ -186,9 +185,8 @@ class MobileSsdQuantizedOpResolverMissingOps
  6316. TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
  6317. ObjectDetectorOptions options;
  6318. - options.mutable_model_file_with_metadata()->set_file_name(
  6319. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  6320. - kMobileSsdWithMetadata));
  6321. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6322. + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
  6323. auto object_detector_or = ObjectDetector::CreateFromOptions(
  6324. options, absl::make_unique<MobileSsdQuantizedOpResolverMissingOps>());
  6325. @@ -203,12 +201,10 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
  6326. TEST_F(CreateFromOptionsTest, FailsWithTwoModelSources) {
  6327. ObjectDetectorOptions options;
  6328. - options.mutable_model_file_with_metadata()->set_file_name(
  6329. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  6330. - kMobileSsdWithMetadata));
  6331. - options.mutable_base_options()->mutable_model_file()->set_file_name(
  6332. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  6333. - kMobileSsdWithMetadata));
  6334. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6335. + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
  6336. + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  6337. + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
  6338. StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or =
  6339. ObjectDetector::CreateFromOptions(options);
  6340. @@ -241,9 +237,8 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
  6341. TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
  6342. ObjectDetectorOptions options;
  6343. - options.mutable_model_file_with_metadata()->set_file_name(
  6344. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  6345. - kMobileSsdWithMetadata));
  6346. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6347. + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
  6348. options.set_max_results(0);
  6349. StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or =
  6350. @@ -260,9 +255,8 @@ TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
  6351. TEST_F(CreateFromOptionsTest, FailsWithCombinedWhitelistAndBlacklist) {
  6352. ObjectDetectorOptions options;
  6353. - options.mutable_model_file_with_metadata()->set_file_name(
  6354. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  6355. - kMobileSsdWithMetadata));
  6356. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6357. + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
  6358. options.add_class_name_whitelist("foo");
  6359. options.add_class_name_blacklist("bar");
  6360. @@ -281,9 +275,8 @@ TEST_F(CreateFromOptionsTest, FailsWithCombinedWhitelistAndBlacklist) {
  6361. TEST_F(CreateFromOptionsTest, SucceedsWithNumberOfThreads) {
  6362. ObjectDetectorOptions options;
  6363. options.set_num_threads(4);
  6364. - options.mutable_model_file_with_metadata()->set_file_name(
  6365. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  6366. - kMobileSsdWithMetadata));
  6367. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6368. + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
  6369. SUPPORT_ASSERT_OK(ObjectDetector::CreateFromOptions(options));
  6370. }
  6371. @@ -295,9 +288,8 @@ INSTANTIATE_TEST_SUITE_P(Default, NumThreadsTest, testing::Values(0, -2));
  6372. TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) {
  6373. ObjectDetectorOptions options;
  6374. options.set_num_threads(GetParam());
  6375. - options.mutable_model_file_with_metadata()->set_file_name(
  6376. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  6377. - kMobileSsdWithMetadata));
  6378. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6379. + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
  6380. StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or =
  6381. ObjectDetector::CreateFromOptions(options);
  6382. @@ -315,51 +307,52 @@ TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) {
  6383. class DetectTest : public tflite_shims::testing::Test {};
  6384. TEST_F(DetectTest, Succeeds) {
  6385. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("cats_and_dogs.jpg"));
  6386. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
  6387. + LoadImage("cats_and_dogs.jpg"));
  6388. std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
  6389. rgb_image.pixel_data,
  6390. FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
  6391. ObjectDetectorOptions options;
  6392. options.set_max_results(4);
  6393. - options.mutable_model_file_with_metadata()->set_file_name(
  6394. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  6395. - kMobileSsdWithMetadata));
  6396. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6397. + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
  6398. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
  6399. - ObjectDetector::CreateFromOptions(options));
  6400. + ObjectDetector::CreateFromOptions(options));
  6401. SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result,
  6402. - object_detector->Detect(*frame_buffer));
  6403. + object_detector->Detect(*frame_buffer));
  6404. ImageDataFree(&rgb_image);
  6405. ExpectApproximatelyEqual(
  6406. result, ParseTextProtoOrDie<DetectionResult>(kExpectResults));
  6407. }
  6408. TEST_F(DetectTest, SucceedswithBaseOptions) {
  6409. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("cats_and_dogs.jpg"));
  6410. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
  6411. + LoadImage("cats_and_dogs.jpg"));
  6412. std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
  6413. rgb_image.pixel_data,
  6414. FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
  6415. ObjectDetectorOptions options;
  6416. options.set_max_results(4);
  6417. - options.mutable_base_options()->mutable_model_file()->set_file_name(
  6418. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  6419. - kMobileSsdWithMetadata));
  6420. + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
  6421. + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
  6422. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
  6423. - ObjectDetector::CreateFromOptions(options));
  6424. + ObjectDetector::CreateFromOptions(options));
  6425. SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result,
  6426. - object_detector->Detect(*frame_buffer));
  6427. + object_detector->Detect(*frame_buffer));
  6428. ImageDataFree(&rgb_image);
  6429. ExpectApproximatelyEqual(
  6430. result, ParseTextProtoOrDie<DetectionResult>(kExpectResults));
  6431. }
  6432. TEST_F(DetectTest, SucceedswithScoreCalibrations) {
  6433. - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("cats_and_dogs.jpg"));
  6434. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
  6435. + LoadImage("cats_and_dogs.jpg"));
  6436. std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
  6437. rgb_image.pixel_data,
  6438. FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
  6439. @@ -371,10 +364,10 @@ TEST_F(DetectTest, SucceedswithScoreCalibrations) {
  6440. kMobileSsdWithMetadataDummyScoreCalibration));
  6441. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
  6442. - ObjectDetector::CreateFromOptions(options));
  6443. + ObjectDetector::CreateFromOptions(options));
  6444. SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result,
  6445. - object_detector->Detect(*frame_buffer));
  6446. + object_detector->Detect(*frame_buffer));
  6447. ImageDataFree(&rgb_image);
  6448. ExpectApproximatelyEqual(
  6449. result, ParseTextProtoOrDie<DetectionResult>(kExpectResults));
  6450. @@ -482,20 +475,21 @@ class PostprocessTest : public tflite_shims::testing::Test {
  6451. TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) {
  6452. ObjectDetectorOptions options;
  6453. - options.mutable_model_file_with_metadata()->set_file_name(
  6454. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  6455. - kMobileSsdWithMetadata));
  6456. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6457. + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
  6458. options.set_score_threshold(0.5);
  6459. SetUp(options);
  6460. ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;
  6461. - SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors,
  6462. - FillAndGetOutputTensors());
  6463. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  6464. + const std::vector<const TfLiteTensor*> output_tensors,
  6465. + FillAndGetOutputTensors());
  6466. - SUPPORT_ASSERT_OK_AND_ASSIGN(DetectionResult result,
  6467. - test_object_detector_->Postprocess(
  6468. - output_tensors, *dummy_frame_buffer_, /*roi=*/{}));
  6469. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  6470. + DetectionResult result,
  6471. + test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_,
  6472. + /*roi=*/{}));
  6473. ExpectApproximatelyEqual(
  6474. result,
  6475. @@ -517,16 +511,16 @@ TEST_F(PostprocessTest, SucceedsWithFrameBufferOrientation) {
  6476. FrameBuffer::Orientation::kBottomRight);
  6477. ObjectDetectorOptions options;
  6478. - options.mutable_model_file_with_metadata()->set_file_name(
  6479. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  6480. - kMobileSsdWithMetadata));
  6481. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6482. + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
  6483. options.set_score_threshold(0.5);
  6484. SetUp(options);
  6485. ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;
  6486. - SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors,
  6487. - FillAndGetOutputTensors());
  6488. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  6489. + const std::vector<const TfLiteTensor*> output_tensors,
  6490. + FillAndGetOutputTensors());
  6491. SUPPORT_ASSERT_OK_AND_ASSIGN(
  6492. DetectionResult result,
  6493. @@ -549,20 +543,21 @@ TEST_F(PostprocessTest, SucceedsWithFrameBufferOrientation) {
  6494. TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) {
  6495. ObjectDetectorOptions options;
  6496. - options.mutable_model_file_with_metadata()->set_file_name(
  6497. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  6498. - kMobileSsdWithMetadata));
  6499. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6500. + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
  6501. options.set_max_results(1);
  6502. SetUp(options);
  6503. ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;
  6504. - SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors,
  6505. - FillAndGetOutputTensors());
  6506. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  6507. + const std::vector<const TfLiteTensor*> output_tensors,
  6508. + FillAndGetOutputTensors());
  6509. - SUPPORT_ASSERT_OK_AND_ASSIGN(DetectionResult result,
  6510. - test_object_detector_->Postprocess(
  6511. - output_tensors, *dummy_frame_buffer_, /*roi=*/{}));
  6512. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  6513. + DetectionResult result,
  6514. + test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_,
  6515. + /*roi=*/{}));
  6516. ExpectApproximatelyEqual(
  6517. result,
  6518. @@ -576,21 +571,22 @@ TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) {
  6519. TEST_F(PostprocessTest, SucceedsWithWhitelistOption) {
  6520. ObjectDetectorOptions options;
  6521. - options.mutable_model_file_with_metadata()->set_file_name(
  6522. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  6523. - kMobileSsdWithMetadata));
  6524. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6525. + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
  6526. options.add_class_name_whitelist("car");
  6527. options.add_class_name_whitelist("motorcycle");
  6528. SetUp(options);
  6529. ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;
  6530. - SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors,
  6531. - FillAndGetOutputTensors());
  6532. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  6533. + const std::vector<const TfLiteTensor*> output_tensors,
  6534. + FillAndGetOutputTensors());
  6535. - SUPPORT_ASSERT_OK_AND_ASSIGN(DetectionResult result,
  6536. - test_object_detector_->Postprocess(
  6537. - output_tensors, *dummy_frame_buffer_, /*roi=*/{}));
  6538. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  6539. + DetectionResult result,
  6540. + test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_,
  6541. + /*roi=*/{}));
  6542. ExpectApproximatelyEqual(
  6543. result,
  6544. @@ -608,9 +604,8 @@ TEST_F(PostprocessTest, SucceedsWithWhitelistOption) {
  6545. TEST_F(PostprocessTest, SucceedsWithBlacklistOption) {
  6546. ObjectDetectorOptions options;
  6547. - options.mutable_model_file_with_metadata()->set_file_name(
  6548. - JoinPath("./" /*test src dir*/, kTestDataDirectory,
  6549. - kMobileSsdWithMetadata));
  6550. + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
  6551. + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
  6552. options.add_class_name_blacklist("car");
  6553. // Setting score threshold to discard the 7 padded-with-zeros results.
  6554. options.set_score_threshold(0.1);
  6555. @@ -618,12 +613,14 @@ TEST_F(PostprocessTest, SucceedsWithBlacklistOption) {
  6556. SetUp(options);
  6557. ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;
  6558. - SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors,
  6559. - FillAndGetOutputTensors());
  6560. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  6561. + const std::vector<const TfLiteTensor*> output_tensors,
  6562. + FillAndGetOutputTensors());
  6563. - SUPPORT_ASSERT_OK_AND_ASSIGN(DetectionResult result,
  6564. - test_object_detector_->Postprocess(
  6565. - output_tensors, *dummy_frame_buffer_, /*roi=*/{}));
  6566. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  6567. + DetectionResult result,
  6568. + test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_,
  6569. + /*roi=*/{}));
  6570. ExpectApproximatelyEqual(
  6571. result,
  6572. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc
  6573. index 7937dbafb090b..c16815cb38061 100644
  6574. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc
  6575. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc
  6576. @@ -21,13 +21,16 @@ namespace tflite {
  6577. namespace task {
  6578. std::string JoinPath(absl::string_view path1, absl::string_view path2) {
  6579. - if (path1.empty()) return std::string(path2);
  6580. - if (path2.empty()) return std::string(path1);
  6581. + if (path1.empty())
  6582. + return std::string(path2);
  6583. + if (path2.empty())
  6584. + return std::string(path1);
  6585. if (path1.back() == '/') {
  6586. if (path2.front() == '/')
  6587. return absl::StrCat(path1, absl::ClippedSubstr(path2, 1));
  6588. } else {
  6589. - if (path2.front() != '/') return absl::StrCat(path1, "/", path2);
  6590. + if (path2.front() != '/')
  6591. + return absl::StrCat(path1, "/", path2);
  6592. }
  6593. return absl::StrCat(path1, path2);
  6594. }
  6595. @@ -44,14 +47,16 @@ std::string JoinPathImpl(bool honor_abs,
  6596. // This size calculation is worst-case: it assumes one extra "/" for every
  6597. // path other than the first.
  6598. size_t total_size = paths.size() - 1;
  6599. - for (const absl::string_view path : paths) total_size += path.size();
  6600. + for (const absl::string_view path : paths)
  6601. + total_size += path.size();
  6602. result.resize(total_size);
  6603. auto begin = result.begin();
  6604. auto out = begin;
  6605. bool trailing_slash = false;
  6606. for (absl::string_view path : paths) {
  6607. - if (path.empty()) continue;
  6608. + if (path.empty())
  6609. + continue;
  6610. if (path.front() == '/') {
  6611. if (honor_abs) {
  6612. out = begin; // wipe out whatever we've built up so far.
  6613. @@ -59,7 +64,8 @@ std::string JoinPathImpl(bool honor_abs,
  6614. path.remove_prefix(1);
  6615. }
  6616. } else {
  6617. - if (!trailing_slash && out != begin) *out++ = '/';
  6618. + if (!trailing_slash && out != begin)
  6619. + *out++ = '/';
  6620. }
  6621. const size_t this_size = path.size();
  6622. memcpy(&*out, path.data(), this_size);
  6623. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h
  6624. index db72bc5d5ae98..1d730d5a6d981 100644
  6625. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h
  6626. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h
  6627. @@ -33,8 +33,10 @@ std::string JoinPathImpl(bool honor_abs,
  6628. std::string JoinPath(absl::string_view path1, absl::string_view path2);
  6629. template <typename... T>
  6630. -inline std::string JoinPath(absl::string_view path1, absl::string_view path2,
  6631. - absl::string_view path3, const T&... args) {
  6632. +inline std::string JoinPath(absl::string_view path1,
  6633. + absl::string_view path2,
  6634. + absl::string_view path3,
  6635. + const T&... args) {
  6636. return internal::JoinPathImpl(false, {path1, path2, path3, args...});
  6637. }
  6638. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc
  6639. index 6a050668edcbe..53c88310dde43 100644
  6640. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc
  6641. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc
  6642. @@ -31,7 +31,8 @@ FlatHashMapBackedWordpiece::FlatHashMapBackedWordpiece(
  6643. }
  6644. tensorflow::text::LookupStatus FlatHashMapBackedWordpiece::Contains(
  6645. - absl::string_view key, bool* value) const {
  6646. + absl::string_view key,
  6647. + bool* value) const {
  6648. *value = index_map_.contains(key);
  6649. return tensorflow::text::LookupStatus();
  6650. }
  6651. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h
  6652. index aec178daf3cc5..1de54fa8f651c 100644
  6653. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h
  6654. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h
  6655. @@ -103,7 +103,8 @@ class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer {
  6656. // Initialize the tokenizer from buffer and size of vocab and tokenizer
  6657. // configs.
  6658. - BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size,
  6659. + BertTokenizer(const char* vocab_buffer_data,
  6660. + size_t vocab_buffer_size,
  6661. const BertTokenizerOptions& options = {})
  6662. : BertTokenizer(
  6663. utils::LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size),
  6664. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc
  6665. index 151161777863f..249bc2d1b6bc2 100644
  6666. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc
  6667. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc
  6668. @@ -31,9 +31,14 @@ using ::tflite::support::utils::StringListToVector;
  6669. extern "C" JNIEXPORT jlong JNICALL
  6670. Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeLoadResource( // NOLINT
  6671. - JNIEnv* env, jobject thiz, jobject vocab_list, jint max_bytes_per_token,
  6672. - jint max_chars_per_sub_token, jstring jsuffix_indicator,
  6673. - jboolean use_unknown_token, jstring junknown_token,
  6674. + JNIEnv* env,
  6675. + jobject thiz,
  6676. + jobject vocab_list,
  6677. + jint max_bytes_per_token,
  6678. + jint max_chars_per_sub_token,
  6679. + jstring jsuffix_indicator,
  6680. + jboolean use_unknown_token,
  6681. + jstring junknown_token,
  6682. jboolean split_unknown_chars) {
  6683. // Convert java.util.List<String> into std::vector<string>
  6684. std::vector<std::string> vocab = StringListToVector(env, vocab_list);
  6685. @@ -66,20 +71,28 @@ Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeLoadResourc
  6686. extern "C" JNIEXPORT jlong JNICALL
  6687. Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeUnloadResource( // NOLINT
  6688. - JNIEnv* env, jobject thiz, jlong handle) {
  6689. + JNIEnv* env,
  6690. + jobject thiz,
  6691. + jlong handle) {
  6692. delete reinterpret_cast<BertTokenizer*>(handle);
  6693. return 0;
  6694. }
  6695. extern "C" JNIEXPORT jobjectArray JNICALL
  6696. Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeTokenize(
  6697. - JNIEnv* env, jobject thiz, jlong handle, jstring jtext) {
  6698. + JNIEnv* env,
  6699. + jobject thiz,
  6700. + jlong handle,
  6701. + jstring jtext) {
  6702. return nativeTokenize(env, handle, jtext);
  6703. }
  6704. extern "C" JNIEXPORT jintArray JNICALL
  6705. Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeConvertTokensToIds( // NOLINT
  6706. - JNIEnv* env, jobject thiz, jlong handle, jobjectArray jtokens) {
  6707. + JNIEnv* env,
  6708. + jobject thiz,
  6709. + jlong handle,
  6710. + jobjectArray jtokens) {
  6711. return nativeConvertTokensToIds(env, handle, jtokens);
  6712. }
  6713. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
  6714. index 832f9df42f824..ded6fbd13ea4a 100644
  6715. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
  6716. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
  6717. @@ -17,7 +17,7 @@ limitations under the License.
  6718. #include <iostream>
  6719. -#include "absl/strings/str_cat.h" // from @com_google_absl
  6720. +#include "absl/strings/str_cat.h" // from @com_google_absl
  6721. #include "absl/strings/substitute.h" // from @com_google_absl
  6722. #include "tensorflow_lite_support/cc/utils/common_utils.h"
  6723. namespace tflite {
  6724. @@ -70,7 +70,7 @@ TokenizerResult RegexTokenizer::Tokenize(const std::string& input) {
  6725. re2::StringPiece extracted_delim_token;
  6726. while (RE2::FindAndConsume(&leftover, delim_re_, &extracted_delim_token)) {
  6727. re2::StringPiece token(last_end.data(),
  6728. - extracted_delim_token.data() - last_end.data());
  6729. + extracted_delim_token.data() - last_end.data());
  6730. bool has_non_empty_token = token.length() > 0;
  6731. last_end = leftover;
  6732. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc
  6733. index 6ecfff0d2baa1..8ca14c52eb262 100644
  6734. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc
  6735. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc
  6736. @@ -20,7 +20,7 @@ limitations under the License.
  6737. #include <utility>
  6738. #include <vector>
  6739. -#include "absl/memory/memory.h" // from @com_google_absl
  6740. +#include "absl/memory/memory.h" // from @com_google_absl
  6741. #include "absl/strings/str_split.h" // from @com_google_absl
  6742. #include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h"
  6743. #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h"
  6744. @@ -34,7 +34,9 @@ using ::tflite::support::utils::GetMappedFileBuffer;
  6745. extern "C" JNIEXPORT jlong JNICALL
  6746. Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeLoadResource( // NOLINT
  6747. - JNIEnv* env, jobject obj, jobject model_buffer) {
  6748. + JNIEnv* env,
  6749. + jobject obj,
  6750. + jobject model_buffer) {
  6751. auto model = GetMappedFileBuffer(env, model_buffer);
  6752. auto handle =
  6753. absl::make_unique<SentencePieceTokenizer>(model.data(), model.size());
  6754. @@ -43,20 +45,28 @@ Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeLo
  6755. extern "C" JNIEXPORT jlong JNICALL
  6756. Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeUnloadResource( // NOLINT
  6757. - JNIEnv* env, jobject obj, jlong handle) {
  6758. + JNIEnv* env,
  6759. + jobject obj,
  6760. + jlong handle) {
  6761. delete reinterpret_cast<SentencePieceTokenizer*>(handle);
  6762. return 0;
  6763. }
  6764. extern "C" JNIEXPORT jobjectArray JNICALL
  6765. Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeTokenize( // NOLINT
  6766. - JNIEnv* env, jobject thiz, jlong handle, jstring jtext) {
  6767. + JNIEnv* env,
  6768. + jobject thiz,
  6769. + jlong handle,
  6770. + jstring jtext) {
  6771. return nativeTokenize(env, handle, jtext);
  6772. }
  6773. extern "C" JNIEXPORT jintArray JNICALL
  6774. Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeConvertTokensToIds( // NOLINT
  6775. - JNIEnv* env, jobject thiz, jlong handle, jobjectArray jtokens) {
  6776. + JNIEnv* env,
  6777. + jobject thiz,
  6778. + jlong handle,
  6779. + jobjectArray jtokens) {
  6780. return nativeConvertTokensToIds(env, handle, jtokens);
  6781. }
  6782. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc
  6783. index a72523be5984e..4e32bc5581a48 100644
  6784. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc
  6785. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc
  6786. @@ -54,7 +54,8 @@ jobjectArray nativeTokenize(JNIEnv* env, jlong handle, jstring jtext) {
  6787. return result;
  6788. }
  6789. -jintArray nativeConvertTokensToIds(JNIEnv* env, jlong handle,
  6790. +jintArray nativeConvertTokensToIds(JNIEnv* env,
  6791. + jlong handle,
  6792. jobjectArray jtokens) {
  6793. if (handle == 0) {
  6794. env->ThrowNew(env->FindClass(kIllegalStateException),
  6795. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h
  6796. index 33677d305a853..fd76f3aa553e4 100644
  6797. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h
  6798. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h
  6799. @@ -25,7 +25,8 @@ namespace support {
  6800. jobjectArray nativeTokenize(JNIEnv* env, jlong handle, jstring jtext);
  6801. -jintArray nativeConvertTokensToIds(JNIEnv* env, jlong handle,
  6802. +jintArray nativeConvertTokensToIds(JNIEnv* env,
  6803. + jlong handle,
  6804. jobjectArray jtokens);
  6805. } // namespace support
  6806. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
  6807. index 28f0137f54278..32957d155dce6 100644
  6808. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
  6809. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
  6810. @@ -73,9 +73,9 @@ StatusOr<std::unique_ptr<Tokenizer>> CreateTokenizerFromProcessUnit(
  6811. }
  6812. case ProcessUnitOptions_SentencePieceTokenizerOptions: {
  6813. return CreateStatusWithPayload(
  6814. - absl::StatusCode::kInvalidArgument,
  6815. - "Chromium does not support sentencepiece tokenization",
  6816. - TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  6817. + absl::StatusCode::kInvalidArgument,
  6818. + "Chromium does not support sentencepiece tokenization",
  6819. + TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  6820. }
  6821. case ProcessUnitOptions_RegexTokenizerOptions: {
  6822. const tflite::RegexTokenizerOptions* options =
  6823. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h
  6824. index 2e50a79963f82..696c5d4e27db7 100644
  6825. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h
  6826. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h
  6827. @@ -26,7 +26,6 @@ namespace support {
  6828. namespace text {
  6829. namespace tokenizer {
  6830. -
  6831. // Create a Tokenizer from model metadata by extracting
  6832. tflite::support::StatusOr<std::unique_ptr<Tokenizer>>
  6833. CreateTokenizerFromProcessUnit(
  6834. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc
  6835. index 84cc0ef6ae52e..3ea6b147fcdd6 100644
  6836. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc
  6837. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc
  6838. @@ -83,7 +83,8 @@ absl::node_hash_map<std::string, int> LoadVocabAndIndexFromFile(
  6839. }
  6840. absl::node_hash_map<std::string, int> LoadVocabAndIndexFromBuffer(
  6841. - const char* vocab_buffer_data, const size_t vocab_buffer_size) {
  6842. + const char* vocab_buffer_data,
  6843. + const size_t vocab_buffer_size) {
  6844. membuf sbuf(const_cast<char*>(vocab_buffer_data),
  6845. const_cast<char*>(vocab_buffer_data + vocab_buffer_size));
  6846. absl::node_hash_map<std::string, int> vocab_index_map;
  6847. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h
  6848. index 6921d2f5ac01b..275c4932f8ec0 100644
  6849. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h
  6850. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h
  6851. @@ -41,7 +41,8 @@ absl::node_hash_map<std::string, int> LoadVocabAndIndexFromFile(
  6852. // Read a vocab buffer with one vocabulary and its corresponding index on each
  6853. // line separated by space, create a map of <vocab, index>.
  6854. absl::node_hash_map<std::string, int> LoadVocabAndIndexFromBuffer(
  6855. - const char* vocab_buffer_data, const size_t vocab_buffer_size);
  6856. + const char* vocab_buffer_data,
  6857. + const size_t vocab_buffer_size);
  6858. } // namespace utils
  6859. } // namespace support
  6860. } // namespace tflite
  6861. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc
  6862. index bf9e93f9aa24a..35ce822951ad8 100644
  6863. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc
  6864. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc
  6865. @@ -18,8 +18,8 @@ limitations under the License.
  6866. #include <dlfcn.h>
  6867. #include <string.h>
  6868. -#include "absl/memory/memory.h" // from @com_google_absl
  6869. -#include "absl/status/status.h" // from @com_google_absl
  6870. +#include "absl/memory/memory.h" // from @com_google_absl
  6871. +#include "absl/status/status.h" // from @com_google_absl
  6872. #include "absl/strings/str_format.h" // from @com_google_absl
  6873. #include "tensorflow/lite/core/shims/c/experimental/acceleration/configuration/delegate_plugin.h"
  6874. #include "tensorflow/lite/core/shims/cc/experimental/acceleration/configuration/delegate_registry.h"
  6875. @@ -168,7 +168,8 @@ void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...) {
  6876. va_end(args);
  6877. }
  6878. -void ThrowExceptionWithMessage(JNIEnv* env, const char* clazz,
  6879. +void ThrowExceptionWithMessage(JNIEnv* env,
  6880. + const char* clazz,
  6881. const char* message) {
  6882. jclass e_class = env->FindClass(clazz);
  6883. if (strcmp(clazz, kAssertionError) == 0) {
  6884. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h
  6885. index 6d15bb43e75b3..f92f838bb9a71 100644
  6886. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h
  6887. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h
  6888. @@ -22,7 +22,7 @@ limitations under the License.
  6889. #include <string>
  6890. #include <vector>
  6891. -#include "absl/status/status.h" // from @com_google_absl
  6892. +#include "absl/status/status.h" // from @com_google_absl
  6893. #include "absl/strings/string_view.h" // from @com_google_absl
  6894. #include "tensorflow_lite_support/cc/port/configuration_proto_inc.h"
  6895. #include "tensorflow_lite_support/cc/port/statusor.h"
  6896. @@ -59,7 +59,9 @@ T CheckNotNull(JNIEnv* env, T&& t) {
  6897. // interable before adding it to the ArrayList.
  6898. template <typename Iterator>
  6899. jobject ConvertVectorToArrayList(
  6900. - JNIEnv* env, const Iterator& begin, const Iterator& end,
  6901. + JNIEnv* env,
  6902. + const Iterator& begin,
  6903. + const Iterator& end,
  6904. std::function<jobject(typename std::iterator_traits<Iterator>::value_type)>
  6905. converter) {
  6906. jclass array_list_class = env->FindClass("java/util/ArrayList");
  6907. @@ -94,7 +96,8 @@ jbyteArray CreateByteArray(JNIEnv* env, const jbyte* data, int num_bytes);
  6908. void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...);
  6909. -void ThrowExceptionWithMessage(JNIEnv* env, const char* clazz,
  6910. +void ThrowExceptionWithMessage(JNIEnv* env,
  6911. + const char* clazz,
  6912. const char* message);
  6913. const char* GetExceptionClassNameForStatusCode(absl::StatusCode status_code);
  6914. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc
  6915. index eb94cb7020475..bb8f1f4d40655 100644
  6916. --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc
  6917. +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc
  6918. @@ -63,7 +63,8 @@ using details_android_java::TensorInfo;
  6919. // Using ctor and dtor to simulate an enter/exit schema like `with` in Python.
  6920. class AsBlock {
  6921. public:
  6922. - AsBlock(CodeWriter* code_writer, const std::string& before,
  6923. + AsBlock(CodeWriter* code_writer,
  6924. + const std::string& before,
  6925. bool trailing_blank_line = false)
  6926. : code_writer_(code_writer), trailing_blank_line_(trailing_blank_line) {
  6927. code_writer_->AppendNoNewLine(before);
  6928. @@ -105,7 +106,9 @@ std::string GetModelVersionedName(const ModelMetadata* metadata) {
  6929. }
  6930. TensorInfo CreateTensorInfo(const TensorMetadata* metadata,
  6931. - const std::string& name, bool is_input, int index,
  6932. + const std::string& name,
  6933. + bool is_input,
  6934. + int index,
  6935. ErrorReporter* err) {
  6936. TensorInfo tensor_info;
  6937. std::string tensor_identifier = is_input ? "input" : "output";
  6938. @@ -273,7 +276,8 @@ bool IsImageUsed(const ModelInfo& model) {
  6939. // The following functions generates the wrapper Java code for a model.
  6940. -bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model,
  6941. +bool GenerateWrapperFileContent(CodeWriter* code_writer,
  6942. + const ModelInfo& model,
  6943. ErrorReporter* err) {
  6944. code_writer->Append("// Generated by TFLite Support.");
  6945. code_writer->Append("package {{PACKAGE}};");
  6946. @@ -291,7 +295,8 @@ bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model,
  6947. return true;
  6948. }
  6949. -bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
  6950. +bool GenerateWrapperImports(CodeWriter* code_writer,
  6951. + const ModelInfo& model,
  6952. ErrorReporter* err) {
  6953. const std::string support_pkg = "org.tensorflow.lite.support.";
  6954. std::vector<std::string> imports{
  6955. @@ -336,7 +341,8 @@ bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
  6956. return true;
  6957. }
  6958. -bool GenerateWrapperClass(CodeWriter* code_writer, const ModelInfo& model,
  6959. +bool GenerateWrapperClass(CodeWriter* code_writer,
  6960. + const ModelInfo& model,
  6961. ErrorReporter* err) {
  6962. code_writer->SetTokenValue("MODEL_VERSIONED_NAME",
  6963. model.model_versioned_name);
  6964. @@ -373,7 +379,8 @@ private static final String MODEL_NAME = "{{MODEL_PATH}}";)");
  6965. return true;
  6966. }
  6967. -bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
  6968. +bool GenerateWrapperOutputs(CodeWriter* code_writer,
  6969. + const ModelInfo& model,
  6970. ErrorReporter* err) {
  6971. code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */");
  6972. auto class_block = AsBlock(code_writer, "public static class Outputs");
  6973. @@ -459,7 +466,8 @@ bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
  6974. return true;
  6975. }
  6976. -bool GenerateWrapperMetadata(CodeWriter* code_writer, const ModelInfo& model,
  6977. +bool GenerateWrapperMetadata(CodeWriter* code_writer,
  6978. + const ModelInfo& model,
  6979. ErrorReporter* err) {
  6980. code_writer->Append(
  6981. "/** Metadata accessors of {@link {{MODEL_CLASS_NAME}}} */");
  6982. @@ -605,7 +613,8 @@ public List<String> get{{NAME_U}}Labels() {
  6983. return true;
  6984. }
  6985. -bool GenerateWrapperAPI(CodeWriter* code_writer, const ModelInfo& model,
  6986. +bool GenerateWrapperAPI(CodeWriter* code_writer,
  6987. + const ModelInfo& model,
  6988. ErrorReporter* err) {
  6989. code_writer->Append(R"(public Metadata getMetadata() {
  6990. return metadata;
  6991. @@ -980,8 +989,10 @@ AndroidJavaGenerator::AndroidJavaGenerator(const std::string& module_root)
  6992. : CodeGenerator(), module_root_(module_root) {}
  6993. GenerationResult AndroidJavaGenerator::Generate(
  6994. - const Model* model, const std::string& package_name,
  6995. - const std::string& model_class_name, const std::string& model_asset_path) {
  6996. + const Model* model,
  6997. + const std::string& package_name,
  6998. + const std::string& model_class_name,
  6999. + const std::string& model_asset_path) {
  7000. GenerationResult result;
  7001. if (model == nullptr) {
  7002. err_.Error(
  7003. @@ -1006,8 +1017,10 @@ GenerationResult AndroidJavaGenerator::Generate(
  7004. }
  7005. GenerationResult AndroidJavaGenerator::Generate(
  7006. - const char* model_storage, const std::string& package_name,
  7007. - const std::string& model_class_name, const std::string& model_asset_path) {
  7008. + const char* model_storage,
  7009. + const std::string& package_name,
  7010. + const std::string& model_class_name,
  7011. + const std::string& model_asset_path) {
  7012. const Model* model = GetModel(model_storage);
  7013. return Generate(model, package_name, model_class_name, model_asset_path);
  7014. }
  7015. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.h b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.h
  7016. index 634ccf69f6c1a..1ea8bb2182a67 100644
  7017. --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.h
  7018. +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.h
  7019. @@ -20,10 +20,10 @@ limitations under the License.
  7020. #include <string>
  7021. #include <vector>
  7022. +#include "tensorflow/lite/schema/schema_generated.h"
  7023. #include "tensorflow_lite_support/codegen/code_generator.h"
  7024. #include "tensorflow_lite_support/codegen/utils.h"
  7025. #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
  7026. -#include "tensorflow/lite/schema/schema_generated.h"
  7027. namespace tflite {
  7028. namespace support {
  7029. @@ -90,7 +90,8 @@ class AndroidJavaGenerator : public CodeGenerator {
  7030. /// as "ImageClassifier", "MobileNetV2" or "MyModel".
  7031. /// - model_asset_path: The relevant path to the model file in the asset.
  7032. // TODO(b/141225157): Automatically generate model_class_name.
  7033. - GenerationResult Generate(const Model* model, const std::string& package_name,
  7034. + GenerationResult Generate(const Model* model,
  7035. + const std::string& package_name,
  7036. const std::string& model_class_name,
  7037. const std::string& model_asset_path);
  7038. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.cc
  7039. index 1337708d4ac66..b6ec55cbc5e8b 100644
  7040. --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.cc
  7041. +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.cc
  7042. @@ -144,7 +144,8 @@ std::string CodeGenerator::NameTensor(const TensorMetadata& tensor,
  7043. }
  7044. void CodeGenerator::ResolveConflictedInputAndOutputNames(
  7045. - std::vector<std::string>* inputs, std::vector<std::string>* outputs) {
  7046. + std::vector<std::string>* inputs,
  7047. + std::vector<std::string>* outputs) {
  7048. std::unordered_set<std::string> io_conflict;
  7049. auto& input_names = *inputs;
  7050. auto& output_names = *outputs;
  7051. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.h b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.h
  7052. index b557773ddcc7a..fe67327986bd7 100644
  7053. --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.h
  7054. +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.h
  7055. @@ -70,7 +70,8 @@ class CodeGenerator {
  7056. static std::string NameTensor(const TensorMetadata& tensor,
  7057. const std::string& default_name);
  7058. static void ResolveConflictedInputAndOutputNames(
  7059. - std::vector<std::string>* input, std::vector<std::string>* output);
  7060. + std::vector<std::string>* input,
  7061. + std::vector<std::string>* output);
  7062. };
  7063. } // namespace codegen
  7064. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator_test.cc
  7065. index 5e9d64a0d8f98..ccc87668ed3cb 100644
  7066. --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator_test.cc
  7067. +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator_test.cc
  7068. @@ -36,7 +36,8 @@ class CodeGeneratorTest : public ::testing::Test {
  7069. return CodeGenerator::ConvertToValidName(name);
  7070. }
  7071. static void ResolveConflictedInputAndOutputNames(
  7072. - std::vector<std::string>* input, std::vector<std::string>* output) {
  7073. + std::vector<std::string>* input,
  7074. + std::vector<std::string>* output) {
  7075. CodeGenerator::ResolveConflictedInputAndOutputNames(input, output);
  7076. }
  7077. };
  7078. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/metadata_helper.h b/third_party/tflite_support/src/tensorflow_lite_support/codegen/metadata_helper.h
  7079. index 8e3dc6abaed66..193dfb2fb23f3 100644
  7080. --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/metadata_helper.h
  7081. +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/metadata_helper.h
  7082. @@ -18,9 +18,9 @@ limitations under the License.
  7083. #include <string>
  7084. +#include "tensorflow/lite/schema/schema_generated.h"
  7085. #include "tensorflow_lite_support/codegen/utils.h"
  7086. #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
  7087. -#include "tensorflow/lite/schema/schema_generated.h"
  7088. namespace tflite {
  7089. namespace support {
  7090. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen_lib.cc
  7091. index 6b2cd5ea9a778..a9da2403afc4f 100644
  7092. --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen_lib.cc
  7093. +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen_lib.cc
  7094. @@ -29,11 +29,10 @@ using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>;
  7095. PYBIND11_MODULE(_pywrap_codegen, m) {
  7096. pybind11::class_<AndroidJavaGenerator>(m, "AndroidJavaGenerator")
  7097. - .def(pybind11::init<const std::string &>())
  7098. - .def("generate",
  7099. - overload_cast_<const char *, const std::string &,
  7100. - const std::string &, const std::string &>()(
  7101. - &AndroidJavaGenerator::Generate))
  7102. + .def(pybind11::init<const std::string&>())
  7103. + .def("generate", overload_cast_<const char*, const std::string&,
  7104. + const std::string&, const std::string&>()(
  7105. + &AndroidJavaGenerator::Generate))
  7106. .def("get_error_message", &AndroidJavaGenerator::GetErrorMessage);
  7107. pybind11::class_<GenerationResult>(m, "GenerationResult")
  7108. .def(pybind11::init<>())
  7109. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/utils.cc
  7110. index c75fc5fae631d..e89d09629dda1 100644
  7111. --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/utils.cc
  7112. +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/utils.cc
  7113. @@ -32,7 +32,8 @@ int ErrorReporter::Error(const char* format, ...) {
  7114. return Report("[ERROR] ", format, args);
  7115. }
  7116. -int ErrorReporter::Report(const char* prefix, const char* format,
  7117. +int ErrorReporter::Report(const char* prefix,
  7118. + const char* format,
  7119. va_list args) {
  7120. char buf[1024];
  7121. int formatted = vsnprintf(buf, sizeof(buf), format, args);
  7122. @@ -69,9 +70,13 @@ void CodeWriter::SetIndentString(const std::string& indent_str) {
  7123. indent_str_ = indent_str;
  7124. }
  7125. -void CodeWriter::Indent() { indent_++; }
  7126. +void CodeWriter::Indent() {
  7127. + indent_++;
  7128. +}
  7129. -void CodeWriter::Outdent() { indent_--; }
  7130. +void CodeWriter::Outdent() {
  7131. + indent_--;
  7132. +}
  7133. std::string CodeWriter::GenerateIndent() const {
  7134. std::string res;
  7135. @@ -82,7 +87,9 @@ std::string CodeWriter::GenerateIndent() const {
  7136. return res;
  7137. }
  7138. -void CodeWriter::Append(const std::string& text) { AppendInternal(text, true); }
  7139. +void CodeWriter::Append(const std::string& text) {
  7140. + AppendInternal(text, true);
  7141. +}
  7142. void CodeWriter::AppendNoNewLine(const std::string& text) {
  7143. AppendInternal(text, false);
  7144. @@ -144,15 +151,21 @@ void CodeWriter::AppendInternal(const std::string& text, bool newline) {
  7145. }
  7146. }
  7147. -void CodeWriter::NewLine() { Append(""); }
  7148. +void CodeWriter::NewLine() {
  7149. + Append("");
  7150. +}
  7151. void CodeWriter::Backspace(int n) {
  7152. buffer_.resize(buffer_.size() > n ? buffer_.size() - n : 0);
  7153. }
  7154. -std::string CodeWriter::ToString() const { return buffer_; }
  7155. +std::string CodeWriter::ToString() const {
  7156. + return buffer_;
  7157. +}
  7158. -bool CodeWriter::IsStreamEmpty() const { return buffer_.empty(); }
  7159. +bool CodeWriter::IsStreamEmpty() const {
  7160. + return buffer_.empty();
  7161. +}
  7162. void CodeWriter::Clear() {
  7163. buffer_.clear();
  7164. @@ -181,11 +194,14 @@ std::string SnakeCaseToCamelCase(const std::string& s) {
  7165. }
  7166. std::string JoinPath(const std::string& a, const std::string& b) {
  7167. - if (a.empty()) return b;
  7168. + if (a.empty())
  7169. + return b;
  7170. std::string a_fixed = a;
  7171. - if (!a_fixed.empty() && a_fixed.back() == '/') a_fixed.pop_back();
  7172. + if (!a_fixed.empty() && a_fixed.back() == '/')
  7173. + a_fixed.pop_back();
  7174. std::string b_fixed = b;
  7175. - if (!b_fixed.empty() && b_fixed.front() == '/') b_fixed.erase(0, 1);
  7176. + if (!b_fixed.empty() && b_fixed.front() == '/')
  7177. + b_fixed.erase(0, 1);
  7178. return a_fixed + "/" + b_fixed;
  7179. }
  7180. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams.cc
  7181. index 3831c63ca17cc..f55ffb907f133 100644
  7182. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams.cc
  7183. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams.cc
  7184. @@ -66,7 +66,9 @@ struct NgramsAttributes {
  7185. string_separator(m["string_separator"].ToString()) {}
  7186. };
  7187. -inline bool OutputIsTensor(TfLiteNode* node) { return NumOutputs(node) == 1; }
  7188. +inline bool OutputIsTensor(TfLiteNode* node) {
  7189. + return NumOutputs(node) == 1;
  7190. +}
  7191. inline int NumRowSplits(TfLiteNode* node) {
  7192. return NumInputs(node) - kRowSplitsStart;
  7193. }
  7194. @@ -176,7 +178,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  7195. std::vector<StringRef> tokens;
  7196. for (int j = input_row_splits[i]; j < input_row_splits[i + 1]; ++j) {
  7197. tokens.emplace_back(GetString(input_values, j));
  7198. - if (tokens.size() < attributes.width) continue;
  7199. + if (tokens.size() < attributes.width)
  7200. + continue;
  7201. tokens.erase(tokens.begin(),
  7202. tokens.begin() + tokens.size() - attributes.width);
  7203. buffer.AddJoinedString(tokens, separator);
  7204. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc
  7205. index b87fcac328623..dc21f37beb3bf 100644
  7206. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc
  7207. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc
  7208. @@ -15,8 +15,8 @@ limitations under the License.
  7209. #include "tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.h"
  7210. -#include "tensorflow_lite_support/custom_ops/kernel/ngrams.h"
  7211. #include "tensorflow/lite/mutable_op_resolver.h"
  7212. +#include "tensorflow_lite_support/custom_ops/kernel/ngrams.h"
  7213. namespace tflite {
  7214. namespace ops {
  7215. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc
  7216. index 91ef47af6fd0f..4a5e671fa0987 100644
  7217. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc
  7218. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc
  7219. @@ -40,7 +40,8 @@ using ::testing::ElementsAreArray;
  7220. class NgramsModel : public SingleOpModel {
  7221. public:
  7222. // Constructor for testing the op with a tf.Tensor
  7223. - NgramsModel(int width, const std::string& string_separator,
  7224. + NgramsModel(int width,
  7225. + const std::string& string_separator,
  7226. const std::vector<std::string>& input_values,
  7227. const std::vector<int>& input_shape) {
  7228. input_values_ = AddInput(TensorType_STRING);
  7229. @@ -56,7 +57,8 @@ class NgramsModel : public SingleOpModel {
  7230. // Constructor for the op with a tf.RaggedTensor
  7231. // Note: This interface uses row_lengths, as they're closer to the
  7232. // dimensions in a TensorShape, but internally everything is row_splits.
  7233. - NgramsModel(int width, const std::string& string_separator,
  7234. + NgramsModel(int width,
  7235. + const std::string& string_separator,
  7236. const std::vector<std::string>& input_values,
  7237. const std::vector<std::vector<int64_t>> nested_row_lengths) {
  7238. std::vector<std::vector<int>> input_shapes;
  7239. @@ -203,8 +205,7 @@ TEST(NgramsTest, TensorMultidimensionalInputWidthTwo) {
  7240. TEST(NgramsTest, RaggedTensorSingleSequenceWidthTwo) {
  7241. std::vector<std::vector<int64_t>> nested_row_lengths;
  7242. nested_row_lengths.push_back({4});
  7243. - NgramsModel m(2, " ", {"this", "is", "a", "test"},
  7244. - nested_row_lengths);
  7245. + NgramsModel m(2, " ", {"this", "is", "a", "test"}, nested_row_lengths);
  7246. EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3));
  7247. EXPECT_THAT(m.ExtractValuesTensorVector(),
  7248. ElementsAre("this is", "is a", "a test"));
  7249. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h
  7250. index ade3c5c178920..811be781d27fe 100644
  7251. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h
  7252. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h
  7253. @@ -20,6 +20,6 @@ limitations under the License.
  7254. // C-function that is called from the Python Wrapper.
  7255. extern "C" void TFLite_RaggedTensorToTensorRegisterer(
  7256. - tflite::MutableOpResolver *resolver);
  7257. + tflite::MutableOpResolver* resolver);
  7258. #endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_RAGGED_PY_TFLITE_REGISTERER_H_
  7259. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc
  7260. index a35a6db9ad48f..9fc73dd0f9778 100644
  7261. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc
  7262. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc
  7263. @@ -71,9 +71,12 @@ TfLiteStatus EvalT(TfLiteContext* context, TfLiteNode* node) {
  7264. // nrows (number of output rows) is the size of the non-broadcast inputs,
  7265. // or 1 if all inputs are scalars.
  7266. std::vector<int> in_sizes;
  7267. - if (!broadcast_starts) in_sizes.push_back(input_starts.dims->data[0]);
  7268. - if (!broadcast_limits) in_sizes.push_back(input_limits.dims->data[0]);
  7269. - if (!broadcast_deltas) in_sizes.push_back(input_deltas.dims->data[0]);
  7270. + if (!broadcast_starts)
  7271. + in_sizes.push_back(input_starts.dims->data[0]);
  7272. + if (!broadcast_limits)
  7273. + in_sizes.push_back(input_limits.dims->data[0]);
  7274. + if (!broadcast_deltas)
  7275. + in_sizes.push_back(input_deltas.dims->data[0]);
  7276. if (std::adjacent_find(std::begin(in_sizes), std::end(in_sizes),
  7277. std::not_equal_to<>()) != std::end(in_sizes)) {
  7278. context->ReportError(
  7279. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc
  7280. index 75a460538aaaa..fc838bee4d98b 100644
  7281. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc
  7282. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc
  7283. @@ -39,7 +39,8 @@ class RaggedRangeOpModel : public SingleOpModel {
  7284. public:
  7285. static TensorType GetType();
  7286. - RaggedRangeOpModel(const std::vector<T>& start, const std::vector<T>& limits,
  7287. + RaggedRangeOpModel(const std::vector<T>& start,
  7288. + const std::vector<T>& limits,
  7289. const std::vector<T>& deltas) {
  7290. const TensorType value_type = GetType();
  7291. std::vector<std::vector<int>> shapes;
  7292. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc
  7293. index 09ac76c71b26c..ff5c14b8e5e08 100644
  7294. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc
  7295. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc
  7296. @@ -140,8 +140,10 @@ RuntimeShape TensorShapeFromTensor(const TfLiteTensor& tensor) {
  7297. }
  7298. const TfLiteTensor* GetRowPartitionTensor(
  7299. - const ConversionAttributes& conversion_attributes, TfLiteContext* context,
  7300. - TfLiteNode* node, int dimension) {
  7301. + const ConversionAttributes& conversion_attributes,
  7302. + TfLiteContext* context,
  7303. + TfLiteNode* node,
  7304. + int dimension) {
  7305. if (conversion_attributes.partition_types.front() ==
  7306. tensorflow::RowPartitionType::FIRST_DIM_SIZE) {
  7307. return &context->tensors[node->inputs->data[kFirstPartitionInputIndex + 1 +
  7308. @@ -211,7 +213,9 @@ int GetMaxWidthRowSplit(const TfLiteTensor* tensor) {
  7309. }
  7310. int GetMaxWidth(const ConversionAttributes& conversion_attributes,
  7311. - TfLiteContext* context, TfLiteNode* node, int dimension) {
  7312. + TfLiteContext* context,
  7313. + TfLiteNode* node,
  7314. + int dimension) {
  7315. const TfLiteTensor* tensor = GetRowPartitionTensor(
  7316. conversion_attributes, context, node, dimension - 1);
  7317. switch (conversion_attributes.GetRowPartitionTypeByDimension(dimension - 1)) {
  7318. @@ -226,7 +230,8 @@ int GetMaxWidth(const ConversionAttributes& conversion_attributes,
  7319. }
  7320. RuntimeShape CombineRaggedTensorToTensorShapes(
  7321. - int ragged_rank, const RuntimeShape& output_shape,
  7322. + int ragged_rank,
  7323. + const RuntimeShape& output_shape,
  7324. const RuntimeShape& value_shape) {
  7325. // TODO(mgubin): No checks, see
  7326. // third_party/tensorflow/core/ops/ragged_to_dense_util.cc
  7327. @@ -247,9 +252,13 @@ RuntimeShape CombineRaggedTensorToTensorShapes(
  7328. }
  7329. RuntimeShape CalculateOutputSize(
  7330. - const ConversionAttributes& conversion_attributes, TfLiteContext* context,
  7331. - TfLiteNode* node, int first_dimension, int ragged_rank,
  7332. - const TfLiteTensor& values, const TfLiteTensor& default_value,
  7333. + const ConversionAttributes& conversion_attributes,
  7334. + TfLiteContext* context,
  7335. + TfLiteNode* node,
  7336. + int first_dimension,
  7337. + int ragged_rank,
  7338. + const TfLiteTensor& values,
  7339. + const TfLiteTensor& default_value,
  7340. const TfLiteTensor& output_shape) {
  7341. RuntimeShape values_shape(values.dims->size, values.dims->data);
  7342. RuntimeShape default_value_shape(default_value.dims->size,
  7343. @@ -331,7 +340,8 @@ void CalculateFirstParentOutputIndex(int first_dimension,
  7344. void CalculateOutputIndexValueRowID(const TfLiteTensor& value_rowids,
  7345. const std::vector<int>& parent_output_index,
  7346. int output_index_multiplier,
  7347. - int output_size, std::vector<int>* result) {
  7348. + int output_size,
  7349. + std::vector<int>* result) {
  7350. const RuntimeShape tensor_shape(value_rowids.dims->size,
  7351. value_rowids.dims->data);
  7352. const int index_size = tensor_shape.FlatSize();
  7353. @@ -380,7 +390,8 @@ void CalculateOutputIndexValueRowID(const TfLiteTensor& value_rowids,
  7354. void CalculateOutputIndexRowSplit(const TfLiteTensor& row_split,
  7355. const std::vector<int>& parent_output_index,
  7356. - int output_index_multiplier, int output_size,
  7357. + int output_index_multiplier,
  7358. + int output_size,
  7359. std::vector<int>* result) {
  7360. const RuntimeShape row_split_shape(row_split.dims->size,
  7361. row_split.dims->data);
  7362. @@ -421,10 +432,14 @@ void CalculateOutputIndexRowSplit(const TfLiteTensor& row_split,
  7363. }
  7364. TfLiteStatus CalculateOutputIndex(
  7365. - const ConversionAttributes& conversion_attributes, TfLiteContext* context,
  7366. - TfLiteNode* node, int dimension,
  7367. - const std::vector<int>& parent_output_index, int output_index_multiplier,
  7368. - int output_size, std::vector<int>* result) {
  7369. + const ConversionAttributes& conversion_attributes,
  7370. + TfLiteContext* context,
  7371. + TfLiteNode* node,
  7372. + int dimension,
  7373. + const std::vector<int>& parent_output_index,
  7374. + int output_index_multiplier,
  7375. + int output_size,
  7376. + std::vector<int>* result) {
  7377. const TfLiteTensor* row_partition_tensor =
  7378. GetRowPartitionTensor(conversion_attributes, context, node, dimension);
  7379. auto partition_type =
  7380. @@ -447,7 +462,8 @@ TfLiteStatus CalculateOutputIndex(
  7381. }
  7382. template <typename VALUE_TYPE>
  7383. -void SetOutputT(TfLiteContext* context, int ragged_rank,
  7384. +void SetOutputT(TfLiteContext* context,
  7385. + int ragged_rank,
  7386. const std::vector<int>& output_index,
  7387. const TfLiteTensor& values_tensor,
  7388. const TfLiteTensor& default_value_tensor,
  7389. @@ -522,7 +538,8 @@ void SetOutputT(TfLiteContext* context, int ragged_rank,
  7390. }
  7391. }
  7392. -void SetOutput(TfLiteContext* context, int ragged_rank,
  7393. +void SetOutput(TfLiteContext* context,
  7394. + int ragged_rank,
  7395. const std::vector<int>& output_index,
  7396. const TfLiteTensor& values_tensor,
  7397. const TfLiteTensor& default_value_tensor,
  7398. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc
  7399. index b1cde57c47c68..2f7a2a95b8478 100644
  7400. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc
  7401. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc
  7402. @@ -82,7 +82,8 @@ class RaggedTensorToTensorOpModel : public SingleOpModel {
  7403. std::vector<int32> GetOutputInt() { return ExtractVector<int32>(output_); }
  7404. void InvokeFloat(const std::vector<int>& shape,
  7405. - const std::vector<float>& values, float default_value,
  7406. + const std::vector<float>& values,
  7407. + float default_value,
  7408. const std::vector<std::vector<int>>& partition_values) {
  7409. PopulateTensor(input_shape_, shape);
  7410. PopulateTensor(input_values_, values);
  7411. @@ -93,7 +94,8 @@ class RaggedTensorToTensorOpModel : public SingleOpModel {
  7412. SingleOpModel::Invoke();
  7413. }
  7414. void InvokeInt(const std::vector<int>& shape,
  7415. - const std::vector<int32>& values, int32 default_value,
  7416. + const std::vector<int32>& values,
  7417. + int32 default_value,
  7418. const std::vector<std::vector<int>>& partition_values) {
  7419. PopulateTensor(input_shape_, shape);
  7420. PopulateTensor(input_values_, values);
  7421. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc
  7422. index 4e2b87de37327..47ba9fdfebcae 100644
  7423. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc
  7424. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc
  7425. @@ -15,8 +15,8 @@ limitations under the License.
  7426. #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h"
  7427. -#include "absl/status/status.h" // from @com_google_absl
  7428. -#include "absl/strings/str_replace.h" // from @com_google_absl
  7429. +#include "absl/status/status.h" // from @com_google_absl
  7430. +#include "absl/strings/str_replace.h" // from @com_google_absl
  7431. #include "src/sentencepiece_model.pb.h" // from @com_google_sentencepiece
  7432. #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config_generated.h"
  7433. #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h"
  7434. @@ -48,7 +48,8 @@ DecodePrecompiledCharsmap(
  7435. }
  7436. tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
  7437. - const std::string& model_config_str, int encoding_offset) {
  7438. + const std::string& model_config_str,
  7439. + int encoding_offset) {
  7440. ::sentencepiece::ModelProto model_config;
  7441. if (!model_config.ParseFromString(model_config_str)) {
  7442. return absl::InvalidArgumentError(
  7443. @@ -128,7 +129,8 @@ tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
  7444. tflite::support::StatusOr<std::string>
  7445. ConvertSentencepieceModelToFlatBufferForDecoder(
  7446. - const std::string& model_config_str, int encoding_offset) {
  7447. + const std::string& model_config_str,
  7448. + int encoding_offset) {
  7449. ::sentencepiece::ModelProto model_config;
  7450. if (!model_config.ParseFromString(model_config_str)) {
  7451. return absl::InvalidArgumentError(
  7452. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h
  7453. index 5687b6287d140..03b3596820886 100644
  7454. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h
  7455. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h
  7456. @@ -27,13 +27,15 @@ namespace sentencepiece {
  7457. // Converts Sentencepiece configuration to flatbuffer format.
  7458. // encoding_offset is used by some encoders that combine different encodings.
  7459. tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
  7460. - const std::string& model_config_str, int encoding_offset = 0);
  7461. + const std::string& model_config_str,
  7462. + int encoding_offset = 0);
  7463. // Converts Sentencepiece configuration to flatbuffer format for encoder.
  7464. // encoding_offset is used by some encoders that combine different encodings.
  7465. tflite::support::StatusOr<std::string>
  7466. ConvertSentencepieceModelToFlatBufferForDecoder(
  7467. - const std::string& model_config_str, int encoding_offset = 0);
  7468. + const std::string& model_config_str,
  7469. + int encoding_offset = 0);
  7470. // The functions that are provided for the Python wrapper.
  7471. std::string ConvertSentencepieceModel(const std::string& model_string);
  7472. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc
  7473. index 8e130ef73b9b6..94161c2ac4c4e 100644
  7474. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc
  7475. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc
  7476. @@ -19,9 +19,9 @@ limitations under the License.
  7477. #include <gmock/gmock.h>
  7478. #include <gtest/gtest.h>
  7479. -#include "absl/flags/flag.h" // from @com_google_absl
  7480. -#include "absl/strings/str_format.h" // from @com_google_absl
  7481. -#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece
  7482. +#include "absl/flags/flag.h" // from @com_google_absl
  7483. +#include "absl/strings/str_format.h" // from @com_google_absl
  7484. +#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece
  7485. #include "src/sentencepiece_processor.h" // from @com_google_sentencepiece
  7486. #include "tensorflow/core/platform/env.h"
  7487. #include "tensorflow_lite_support/cc/test/test_utils.h"
  7488. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc
  7489. index 45fde32237c65..4148f8e96627a 100644
  7490. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc
  7491. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc
  7492. @@ -31,7 +31,8 @@ const char kSpaceSymbol[] = "\xe2\x96\x81";
  7493. template <typename processing_callback>
  7494. std::tuple<std::string, std::vector<int>> process_string(
  7495. - const std::string& input, const std::vector<int>& offsets,
  7496. + const std::string& input,
  7497. + const std::vector<int>& offsets,
  7498. const processing_callback& pc) {
  7499. std::string result_string;
  7500. result_string.reserve(input.size());
  7501. @@ -78,7 +79,9 @@ std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data,
  7502. }
  7503. std::tuple<int, utils::string_view> find_replacement(
  7504. - const char* data, int len, const DoubleArrayTrie& dat,
  7505. + const char* data,
  7506. + int len,
  7507. + const DoubleArrayTrie& dat,
  7508. const flatbuffers::Vector<int8_t>& replacements) {
  7509. const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len));
  7510. if (!max_match.empty()) {
  7511. @@ -94,7 +97,8 @@ std::tuple<int, utils::string_view> find_replacement(
  7512. } // namespace
  7513. std::tuple<std::string, std::vector<int>> NormalizeString(
  7514. - const std::string& in_string, const EncoderConfig& config) {
  7515. + const std::string& in_string,
  7516. + const EncoderConfig& config) {
  7517. std::vector<int> output_offsets;
  7518. std::string result = in_string;
  7519. output_offsets.reserve(in_string.length());
  7520. @@ -145,8 +149,10 @@ std::tuple<std::string, std::vector<int>> NormalizeString(
  7521. EncoderResult EncodeNormalizedString(const std::string& str,
  7522. const std::vector<int>& offsets,
  7523. - const EncoderConfig& config, bool add_bos,
  7524. - bool add_eos, bool reverse) {
  7525. + const EncoderConfig& config,
  7526. + bool add_bos,
  7527. + bool add_eos,
  7528. + bool reverse) {
  7529. const DoubleArrayTrie piece_matcher(config.pieces()->nodes());
  7530. const flatbuffers::Vector<float>* piece_scores = config.pieces_scores();
  7531. const int unknown_code = config.unknown_code();
  7532. @@ -219,8 +225,11 @@ EncoderResult EncodeNormalizedString(const std::string& str,
  7533. return result;
  7534. }
  7535. -EncoderResult EncodeString(const std::string& string, const void* config_buffer,
  7536. - bool add_bos, bool add_eos, bool reverse) {
  7537. +EncoderResult EncodeString(const std::string& string,
  7538. + const void* config_buffer,
  7539. + bool add_bos,
  7540. + bool add_eos,
  7541. + bool reverse) {
  7542. // Get the config from the buffer.
  7543. const EncoderConfig* config = GetEncoderConfig(config_buffer);
  7544. if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) {
  7545. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h
  7546. index 44d6e88f2531c..b89154cbfa396 100644
  7547. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h
  7548. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h
  7549. @@ -37,12 +37,16 @@ struct EncoderResult {
  7550. std::vector<int> offsets;
  7551. };
  7552. std::tuple<std::string, std::vector<int>> NormalizeString(
  7553. - const std::string& in_string, const EncoderConfig& config);
  7554. + const std::string& in_string,
  7555. + const EncoderConfig& config);
  7556. // Encodes one string and returns ids and offsets. Takes the configuration as a
  7557. // type-erased buffer.
  7558. -EncoderResult EncodeString(const std::string& string, const void* config_buffer,
  7559. - bool add_bos, bool add_eos, bool reverse);
  7560. +EncoderResult EncodeString(const std::string& string,
  7561. + const void* config_buffer,
  7562. + bool add_bos,
  7563. + bool add_eos,
  7564. + bool reverse);
  7565. } // namespace sentencepiece
  7566. } // namespace custom
  7567. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc
  7568. index e2787c785e8c4..dd956a22b26c1 100644
  7569. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc
  7570. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc
  7571. @@ -19,10 +19,10 @@ limitations under the License.
  7572. #include <gmock/gmock.h>
  7573. #include <gtest/gtest.h>
  7574. -#include "absl/flags/flag.h" // from @com_google_absl
  7575. -#include "absl/status/status.h" // from @com_google_absl
  7576. -#include "absl/strings/str_format.h" // from @com_google_absl
  7577. -#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece
  7578. +#include "absl/flags/flag.h" // from @com_google_absl
  7579. +#include "absl/status/status.h" // from @com_google_absl
  7580. +#include "absl/strings/str_format.h" // from @com_google_absl
  7581. +#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece
  7582. #include "src/sentencepiece_processor.h" // from @com_google_sentencepiece
  7583. #include "tensorflow/core/platform/env.h"
  7584. #include "tensorflow_lite_support/cc/test/test_utils.h"
  7585. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h
  7586. index deb4e4ee08dc2..3efcfefc6438d 100644
  7587. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h
  7588. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h
  7589. @@ -20,6 +20,6 @@ limitations under the License.
  7590. // C-function that is called from the Python Wrapper.
  7591. extern "C" void TFLite_SentencepieceTokenizerRegisterer(
  7592. - tflite::MutableOpResolver *resolver);
  7593. + tflite::MutableOpResolver* resolver);
  7594. #endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_PY_TFLITE_REGISTERER_H_
  7595. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc
  7596. index 54b34e4e33196..f5be376b45e12 100644
  7597. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc
  7598. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc
  7599. @@ -35,7 +35,8 @@ namespace detokenizer {
  7600. constexpr int kOutputValuesInd = 0;
  7601. // Initializes text encoder object from serialized parameters.
  7602. -void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
  7603. +void* Initialize(TfLiteContext* /*context*/,
  7604. + const char* /*buffer*/,
  7605. size_t /*length*/) {
  7606. return nullptr;
  7607. }
  7608. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc
  7609. index 41fc5aa28bf30..68f8e64492394 100644
  7610. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc
  7611. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc
  7612. @@ -16,16 +16,16 @@ limitations under the License.
  7613. #include <iterator>
  7614. #include <vector>
  7615. -#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
  7616. -#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h"
  7617. #include "tensorflow/core/framework/op.h"
  7618. #include "tensorflow/core/framework/op_kernel.h"
  7619. #include "tensorflow/core/framework/shape_inference.h"
  7620. #include "tensorflow/core/framework/tensor.h"
  7621. #include "tensorflow/core/protobuf/error_codes.pb.h"
  7622. +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
  7623. +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h"
  7624. namespace tensorflow {
  7625. -namespace ops{
  7626. +namespace ops {
  7627. // copied from third_party/tensorflow_text/core/ops/sentencepiece_ops.cc
  7628. REGISTER_OP("TFSentencepieceTokenizeOp")
  7629. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc
  7630. index 8309a6a2616fd..edb0160b508a3 100644
  7631. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc
  7632. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc
  7633. @@ -16,8 +16,6 @@ limitations under the License.
  7634. /**
  7635. * Sentencepiece tflite tokenizer implementation.
  7636. */
  7637. -#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
  7638. -#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h"
  7639. #include "flatbuffers/flexbuffers.h" // from @flatbuffers
  7640. #include "tensorflow/lite/c/common.h"
  7641. #include "tensorflow/lite/context.h"
  7642. @@ -25,6 +23,8 @@ limitations under the License.
  7643. #include "tensorflow/lite/kernels/kernel_util.h"
  7644. #include "tensorflow/lite/model.h"
  7645. #include "tensorflow/lite/string_util.h"
  7646. +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
  7647. +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h"
  7648. namespace tflite {
  7649. namespace ops {
  7650. @@ -47,7 +47,8 @@ TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
  7651. } // namespace
  7652. // Initializes text encoder object from serialized parameters.
  7653. -void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
  7654. +void* Initialize(TfLiteContext* /*context*/,
  7655. + const char* /*buffer*/,
  7656. size_t /*length*/) {
  7657. return nullptr;
  7658. }
  7659. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc
  7660. index dad2f0004be06..8096a5008bd12 100644
  7661. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc
  7662. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc
  7663. @@ -19,10 +19,10 @@ limitations under the License.
  7664. #include <utility>
  7665. #include <vector>
  7666. +#include "libutf/utf.h"
  7667. #include "tensorflow/lite/context.h"
  7668. #include "tensorflow/lite/kernels/kernel_util.h"
  7669. #include "tensorflow/lite/string_util.h"
  7670. -#include "libutf/utf.h"
  7671. constexpr int kInput = 0;
  7672. constexpr int kOutputValues = 0;
  7673. @@ -49,7 +49,7 @@ inline bool OutputIsPaddedTensor(TfLiteNode* node) {
  7674. }
  7675. inline int charntorune(Rune* r, const char* s, int n) {
  7676. - const int bytes_read = chartorune(r, const_cast<char *>(s));
  7677. + const int bytes_read = chartorune(r, const_cast<char*>(s));
  7678. if (bytes_read > n) {
  7679. *r = Runeerror;
  7680. return 0;
  7681. @@ -66,7 +66,8 @@ std::vector<std::pair<const char*, int>> Tokenize(StringRef str) {
  7682. while (n > 0) {
  7683. Rune r;
  7684. int c = charntorune(&r, p, n);
  7685. - if (r == Runeerror) break;
  7686. + if (r == Runeerror)
  7687. + break;
  7688. if (isspacerune(r)) {
  7689. if (start != nullptr) {
  7690. @@ -91,7 +92,8 @@ std::vector<std::pair<const char*, int>> Tokenize(StringRef str) {
  7691. TfLiteStatus WritePaddedOutput(
  7692. const std::vector<std::vector<std::pair<const char*, int>>>& list_of_tokens,
  7693. - const TfLiteTensor* input, TfLiteTensor* output_values) {
  7694. + const TfLiteTensor* input,
  7695. + TfLiteTensor* output_values) {
  7696. TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(input) + 1);
  7697. for (int i = 0; i < NumDimensions(input); ++i) {
  7698. output_shape->data[i] = SizeOfDimension(input, i);
  7699. @@ -118,7 +120,8 @@ TfLiteStatus WritePaddedOutput(
  7700. TfLiteStatus WriteRaggedOutput(
  7701. const std::vector<std::vector<std::pair<const char*, int>>>& list_of_tokens,
  7702. - const TfLiteTensor* input, TfLiteTensor* output_values,
  7703. + const TfLiteTensor* input,
  7704. + TfLiteTensor* output_values,
  7705. std::vector<TfLiteTensor*> nested_row_splits) {
  7706. // The outer dimensions of the ragged tensor are all non-ragged.
  7707. for (int i = 0; i < nested_row_splits.size() - 1; ++i) {
  7708. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc
  7709. index 534fbef4aff2d..6166bc149bc00 100644
  7710. --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc
  7711. +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc
  7712. @@ -15,8 +15,8 @@ limitations under the License.
  7713. #include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.h"
  7714. -#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h"
  7715. #include "tensorflow/lite/mutable_op_resolver.h"
  7716. +#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h"
  7717. namespace tflite {
  7718. namespace ops {
  7719. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc
  7720. index 7447870046f48..904673a95b799 100644
  7721. --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc
  7722. +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc
  7723. @@ -28,18 +28,26 @@ limitations under the License.
  7724. #include "absl/flags/parse.h" // from @com_google_absl
  7725. #include "tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h"
  7726. -ABSL_FLAG(std::string, model_path, "",
  7727. +ABSL_FLAG(std::string,
  7728. + model_path,
  7729. + "",
  7730. "Absolute path to the '.tflite' audio classification model.");
  7731. -ABSL_FLAG(std::string, audio_wav_path, "",
  7732. +ABSL_FLAG(std::string,
  7733. + audio_wav_path,
  7734. + "",
  7735. "Absolute path to the 16-bit PCM WAV file to classify. The WAV "
  7736. "file must be monochannel and has a sampling rate matches the model "
  7737. "expected sampling rate (as in the Metadata). If the WAV file is "
  7738. "longer than what the model requires, only the beginning section is "
  7739. "used for inference.");
  7740. -ABSL_FLAG(float, score_threshold, 0.001f,
  7741. +ABSL_FLAG(float,
  7742. + score_threshold,
  7743. + 0.001f,
  7744. "Apply a filter on the results. Only display classes with score "
  7745. "higher than the threshold.");
  7746. -ABSL_FLAG(bool, use_coral, false,
  7747. +ABSL_FLAG(bool,
  7748. + use_coral,
  7749. + false,
  7750. "If true, inference will be delegated to a connected Coral Edge TPU "
  7751. "device.");
  7752. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc
  7753. index 36d6633d902e3..a843501ec3d75 100644
  7754. --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc
  7755. +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc
  7756. @@ -19,7 +19,7 @@ limitations under the License.
  7757. #include <string>
  7758. #include <vector>
  7759. -#include "absl/status/status.h" // from @com_google_absl
  7760. +#include "absl/status/status.h" // from @com_google_absl
  7761. #include "absl/strings/str_format.h" // from @com_google_absl
  7762. #include "tensorflow_lite_support/cc/port/status_macros.h"
  7763. #include "tensorflow_lite_support/cc/port/statusor.h"
  7764. @@ -34,7 +34,8 @@ namespace task {
  7765. namespace audio {
  7766. tflite::support::StatusOr<AudioBuffer> LoadAudioBufferFromFile(
  7767. - const std::string& wav_file, int buffer_size,
  7768. + const std::string& wav_file,
  7769. + int buffer_size,
  7770. std::vector<float>* wav_data) {
  7771. std::string contents = ReadFile(wav_file);
  7772. @@ -55,7 +56,8 @@ tflite::support::StatusOr<AudioBuffer> LoadAudioBufferFromFile(
  7773. }
  7774. tflite::support::StatusOr<ClassificationResult> Classify(
  7775. - const std::string& model_path, const std::string& wav_file,
  7776. + const std::string& model_path,
  7777. + const std::string& wav_file,
  7778. bool use_coral) {
  7779. AudioClassifierOptions options;
  7780. options.mutable_base_options()->mutable_model_file()->set_file_name(
  7781. @@ -97,7 +99,8 @@ void Display(const ClassificationResult& result, float score_threshold) {
  7782. std::cout << absl::StrFormat("\nHead[%d]: %s\n", i, head.head_name());
  7783. for (int j = 0; j < head.classes_size(); j++) {
  7784. const auto& category = head.classes(j);
  7785. - if (category.score() < score_threshold) continue;
  7786. + if (category.score() < score_threshold)
  7787. + continue;
  7788. std::cout << absl::StrFormat("\tcategory[%s]: %.5f\t",
  7789. category.class_name(), category.score());
  7790. if (!category.display_name().empty()) {
  7791. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h
  7792. index 6d23078ba3e19..13b2d7792e025 100644
  7793. --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h
  7794. +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h
  7795. @@ -28,7 +28,8 @@ namespace audio {
  7796. // than what the model requires, only the beginning section is used for
  7797. // inference.
  7798. tflite::support::StatusOr<ClassificationResult> Classify(
  7799. - const std::string& model_path, const std::string& wav_file,
  7800. + const std::string& model_path,
  7801. + const std::string& wav_file,
  7802. bool use_coral = false);
  7803. // Prints the output classification result in the standard output. It only
  7804. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc
  7805. index 02eed2332b2e4..5203200808d60 100644
  7806. --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc
  7807. +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc
  7808. @@ -15,18 +15,22 @@ limitations under the License.
  7809. #include <iostream>
  7810. #include <limits>
  7811. -#include "absl/flags/flag.h" // from @com_google_absl
  7812. -#include "absl/flags/parse.h" // from @com_google_absl
  7813. -#include "absl/status/status.h" // from @com_google_absl
  7814. +#include "absl/flags/flag.h" // from @com_google_absl
  7815. +#include "absl/flags/parse.h" // from @com_google_absl
  7816. +#include "absl/status/status.h" // from @com_google_absl
  7817. #include "absl/strings/str_format.h" // from @com_google_absl
  7818. #include "tensorflow_lite_support/cc/port/statusor.h"
  7819. #include "tensorflow_lite_support/cc/task/core/category.h"
  7820. #include "tensorflow_lite_support/cc/task/text/bert_nl_classifier.h"
  7821. -ABSL_FLAG(std::string, model_path, "",
  7822. +ABSL_FLAG(std::string,
  7823. + model_path,
  7824. + "",
  7825. "Absolute path to the '.tflite' bert classification model.");
  7826. ABSL_FLAG(std::string, text, "", "Text to classify.");
  7827. -ABSL_FLAG(bool, use_coral, false,
  7828. +ABSL_FLAG(bool,
  7829. + use_coral,
  7830. + false,
  7831. "If true, inference will be delegated to a connected Coral Edge TPU "
  7832. "device.");
  7833. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc
  7834. index 4eaa2bbbdd9f5..f2577cfad54c2 100644
  7835. --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc
  7836. +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc
  7837. @@ -15,19 +15,25 @@ limitations under the License.
  7838. #include <iostream>
  7839. #include <limits>
  7840. -#include "absl/flags/flag.h" // from @com_google_absl
  7841. -#include "absl/flags/parse.h" // from @com_google_absl
  7842. -#include "absl/status/status.h" // from @com_google_absl
  7843. +#include "absl/flags/flag.h" // from @com_google_absl
  7844. +#include "absl/flags/parse.h" // from @com_google_absl
  7845. +#include "absl/status/status.h" // from @com_google_absl
  7846. #include "absl/strings/str_format.h" // from @com_google_absl
  7847. #include "tensorflow_lite_support/cc/port/statusor.h"
  7848. #include "tensorflow_lite_support/cc/task/text/bert_question_answerer.h"
  7849. -ABSL_FLAG(std::string, model_path, "",
  7850. +ABSL_FLAG(std::string,
  7851. + model_path,
  7852. + "",
  7853. "Absolute path to the '.tflite' bert question answerer model.");
  7854. ABSL_FLAG(std::string, question, "", "Question to ask.");
  7855. -ABSL_FLAG(std::string, context, "",
  7856. +ABSL_FLAG(std::string,
  7857. + context,
  7858. + "",
  7859. "Context the asked question is based upon.");
  7860. -ABSL_FLAG(bool, use_coral, false,
  7861. +ABSL_FLAG(bool,
  7862. + use_coral,
  7863. + false,
  7864. "If true, inference will be delegated to a connected Coral Edge TPU "
  7865. "device.");
  7866. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc
  7867. index 49f233ce1e74c..613744ffdb20b 100644
  7868. --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc
  7869. +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc
  7870. @@ -15,18 +15,22 @@ limitations under the License.
  7871. #include <iostream>
  7872. #include <limits>
  7873. -#include "absl/flags/flag.h" // from @com_google_absl
  7874. -#include "absl/flags/parse.h" // from @com_google_absl
  7875. -#include "absl/status/status.h" // from @com_google_absl
  7876. +#include "absl/flags/flag.h" // from @com_google_absl
  7877. +#include "absl/flags/parse.h" // from @com_google_absl
  7878. +#include "absl/status/status.h" // from @com_google_absl
  7879. #include "absl/strings/str_format.h" // from @com_google_absl
  7880. #include "tensorflow_lite_support/cc/port/statusor.h"
  7881. #include "tensorflow_lite_support/cc/task/core/category.h"
  7882. #include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
  7883. -ABSL_FLAG(std::string, model_path, "",
  7884. +ABSL_FLAG(std::string,
  7885. + model_path,
  7886. + "",
  7887. "Absolute path to the '.tflite' classification model.");
  7888. ABSL_FLAG(std::string, text, "", "Text to classify.");
  7889. -ABSL_FLAG(bool, use_coral, false,
  7890. +ABSL_FLAG(bool,
  7891. + use_coral,
  7892. + false,
  7893. "If true, inference will be delegated to a connected Coral Edge TPU "
  7894. "device.");
  7895. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_embedder_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_embedder_demo.cc
  7896. index 875b5f4a771bd..eca8a002d3293 100644
  7897. --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_embedder_demo.cc
  7898. +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_embedder_demo.cc
  7899. @@ -24,9 +24,9 @@ limitations under the License.
  7900. #include <iostream>
  7901. #include <memory>
  7902. -#include "absl/flags/flag.h" // from @com_google_absl
  7903. -#include "absl/flags/parse.h" // from @com_google_absl
  7904. -#include "absl/status/status.h" // from @com_google_absl
  7905. +#include "absl/flags/flag.h" // from @com_google_absl
  7906. +#include "absl/flags/parse.h" // from @com_google_absl
  7907. +#include "absl/status/status.h" // from @com_google_absl
  7908. #include "absl/strings/str_format.h" // from @com_google_absl
  7909. #include "tensorflow_lite_support/cc/port/configuration_proto_inc.h"
  7910. #include "tensorflow_lite_support/cc/port/status_macros.h"
  7911. @@ -36,19 +36,29 @@ limitations under the License.
  7912. #include "tensorflow_lite_support/cc/task/text/text_embedder.h"
  7913. #include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"
  7914. -ABSL_FLAG(std::string, model_path, "",
  7915. +ABSL_FLAG(std::string,
  7916. + model_path,
  7917. + "",
  7918. "Absolute path to the '.tflite' text embedder model.");
  7919. -ABSL_FLAG(std::string, first_sentence, "",
  7920. +ABSL_FLAG(std::string,
  7921. + first_sentence,
  7922. + "",
  7923. "First sentence, whose feature vector will be extracted and compared "
  7924. "to the second sentence using cosine similarity.");
  7925. -ABSL_FLAG(std::string, second_sentence, "",
  7926. +ABSL_FLAG(std::string,
  7927. + second_sentence,
  7928. + "",
  7929. "Second sentence, whose feature vector will be extracted and "
  7930. "compared to the first sentence using cosine similarity.");
  7931. -ABSL_FLAG(bool, l2_normalize, false,
  7932. +ABSL_FLAG(bool,
  7933. + l2_normalize,
  7934. + false,
  7935. "If true, the raw feature vectors returned by the image embedder "
  7936. "will be normalized with L2-norm. Generally only needed if the model "
  7937. "doesn't already contain a L2_NORMALIZATION TFLite Op.");
  7938. -ABSL_FLAG(bool, use_coral, false,
  7939. +ABSL_FLAG(bool,
  7940. + use_coral,
  7941. + false,
  7942. "If true, inference will be delegated to a connected Coral Edge TPU "
  7943. "device.");
  7944. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_searcher_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_searcher_demo.cc
  7945. index 5ea9b7e63b50e..0299428964797 100644
  7946. --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_searcher_demo.cc
  7947. +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_searcher_demo.cc
  7948. @@ -24,9 +24,9 @@ limitations under the License.
  7949. #include <iostream>
  7950. #include <memory>
  7951. -#include "absl/flags/flag.h" // from @com_google_absl
  7952. -#include "absl/flags/parse.h" // from @com_google_absl
  7953. -#include "absl/status/status.h" // from @com_google_absl
  7954. +#include "absl/flags/flag.h" // from @com_google_absl
  7955. +#include "absl/flags/parse.h" // from @com_google_absl
  7956. +#include "absl/status/status.h" // from @com_google_absl
  7957. #include "absl/strings/str_format.h" // from @com_google_absl
  7958. #include "tensorflow_lite_support/cc/port/configuration_proto_inc.h"
  7959. #include "tensorflow_lite_support/cc/port/status_macros.h"
  7960. @@ -39,21 +39,33 @@ limitations under the License.
  7961. #include "tensorflow_lite_support/cc/task/text/text_searcher.h"
  7962. #include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"
  7963. -ABSL_FLAG(std::string, model_path, "",
  7964. +ABSL_FLAG(std::string,
  7965. + model_path,
  7966. + "",
  7967. "Absolute path to the '.tflite' text embedder model.");
  7968. -ABSL_FLAG(std::string, index_path, "",
  7969. +ABSL_FLAG(std::string,
  7970. + index_path,
  7971. + "",
  7972. "Absolute path to the index to search into. Mandatory only if the "
  7973. "index is not attached to the output tensor metadata of the embedder "
  7974. "model as an AssociatedFile with type SCANN_INDEX_FILE.");
  7975. -ABSL_FLAG(std::string, input_sentence, "",
  7976. +ABSL_FLAG(std::string,
  7977. + input_sentence,
  7978. + "",
  7979. "Input sentence whose nearest-neighbors to search for in the index.");
  7980. -ABSL_FLAG(int32, max_results, 5,
  7981. +ABSL_FLAG(int32,
  7982. + max_results,
  7983. + 5,
  7984. "Maximum number of nearest-neghbors to display.");
  7985. -ABSL_FLAG(bool, l2_normalize, false,
  7986. +ABSL_FLAG(bool,
  7987. + l2_normalize,
  7988. + false,
  7989. "If true, the raw feature vectors returned by the image embedder "
  7990. "will be normalized with L2-norm. Generally only needed if the model "
  7991. "doesn't already contain a L2_NORMALIZATION TFLite Op.");
  7992. -ABSL_FLAG(bool, use_coral, false,
  7993. +ABSL_FLAG(bool,
  7994. + use_coral,
  7995. + false,
  7996. "If true, inference will be delegated to a connected Coral Edge TPU "
  7997. "device.");
  7998. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc
  7999. index 076a60a2330af..f7621a5a8a1b4 100644
  8000. --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc
  8001. +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc
  8002. @@ -14,9 +14,9 @@ limitations under the License.
  8003. ==============================================================================*/
  8004. // Demostration the usage of UniversalSentenceEncoderQA.
  8005. -#include "absl/flags/flag.h" // from @com_google_absl
  8006. -#include "absl/flags/parse.h" // from @com_google_absl
  8007. -#include "absl/status/status.h" // from @com_google_absl
  8008. +#include "absl/flags/flag.h" // from @com_google_absl
  8009. +#include "absl/flags/parse.h" // from @com_google_absl
  8010. +#include "absl/status/status.h" // from @com_google_absl
  8011. #include "absl/strings/str_split.h" // from @com_google_absl
  8012. #include "tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h"
  8013. #include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"
  8014. @@ -29,12 +29,17 @@ using tflite::task::text::RetrievalOutput;
  8015. using tflite::task::text::UniversalSentenceEncoderQA;
  8016. } // namespace
  8017. -ABSL_FLAG(std::string, model_path, "",
  8018. +ABSL_FLAG(std::string,
  8019. + model_path,
  8020. + "",
  8021. "Absolute path to the '.tflite' UniversalSentenceEncoderQA model.");
  8022. -ABSL_FLAG(std::string, question, "How are you feeling today?",
  8023. +ABSL_FLAG(std::string,
  8024. + question,
  8025. + "How are you feeling today?",
  8026. "Question to ask.");
  8027. ABSL_FLAG(
  8028. - std::string, answers,
  8029. + std::string,
  8030. + answers,
  8031. "I'm not feeling very well.:Paris is the capital of France.:He looks good.",
  8032. "Candidate answers seperated by `:`.");
  8033. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc
  8034. index f29bd2de9c535..0904920faa7dd 100644
  8035. --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc
  8036. +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc
  8037. @@ -22,9 +22,9 @@ limitations under the License.
  8038. #include <iostream>
  8039. -#include "absl/flags/flag.h" // from @com_google_absl
  8040. -#include "absl/flags/parse.h" // from @com_google_absl
  8041. -#include "absl/status/status.h" // from @com_google_absl
  8042. +#include "absl/flags/flag.h" // from @com_google_absl
  8043. +#include "absl/flags/parse.h" // from @com_google_absl
  8044. +#include "absl/status/status.h" // from @com_google_absl
  8045. #include "absl/strings/str_format.h" // from @com_google_absl
  8046. #include "tensorflow_lite_support/cc/port/statusor.h"
  8047. #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
  8048. @@ -36,29 +36,43 @@ limitations under the License.
  8049. #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
  8050. #include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
  8051. -ABSL_FLAG(std::string, model_path, "",
  8052. +ABSL_FLAG(std::string,
  8053. + model_path,
  8054. + "",
  8055. "Absolute path to the '.tflite' image classifier model.");
  8056. -ABSL_FLAG(std::string, image_path, "",
  8057. +ABSL_FLAG(std::string,
  8058. + image_path,
  8059. + "",
  8060. "Absolute path to the image to classify. The image must be RGB or "
  8061. "RGBA (grayscale is not supported). The image EXIF orientation "
  8062. "flag, if any, is NOT taken into account.");
  8063. -ABSL_FLAG(int32, max_results, 5,
  8064. +ABSL_FLAG(int32,
  8065. + max_results,
  8066. + 5,
  8067. "Maximum number of classification results to display.");
  8068. -ABSL_FLAG(float, score_threshold, 0,
  8069. +ABSL_FLAG(float,
  8070. + score_threshold,
  8071. + 0,
  8072. "Classification results with a confidence score below this value are "
  8073. "rejected. If >= 0, overrides the score threshold(s) provided in the "
  8074. "TFLite Model Metadata. Ignored otherwise.");
  8075. ABSL_FLAG(
  8076. - std::vector<std::string>, class_name_whitelist, {},
  8077. + std::vector<std::string>,
  8078. + class_name_whitelist,
  8079. + {},
  8080. "Comma-separated list of class names that acts as a whitelist. If "
  8081. "non-empty, classification results whose 'class_name' is not in this list "
  8082. "are filtered out. Mutually exclusive with 'class_name_blacklist'.");
  8083. ABSL_FLAG(
  8084. - std::vector<std::string>, class_name_blacklist, {},
  8085. + std::vector<std::string>,
  8086. + class_name_blacklist,
  8087. + {},
  8088. "Comma-separated list of class names that acts as a blacklist. If "
  8089. "non-empty, classification results whose 'class_name' is in this list "
  8090. "are filtered out. Mutually exclusive with 'class_name_whitelist'.");
  8091. -ABSL_FLAG(bool, use_coral, false,
  8092. +ABSL_FLAG(bool,
  8093. + use_coral,
  8094. + false,
  8095. "If true, inference will be delegated to a connected Coral Edge TPU "
  8096. "device.");
  8097. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc
  8098. index 50d615a486751..f8b1796bc3865 100644
  8099. --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc
  8100. +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc
  8101. @@ -26,9 +26,9 @@ limitations under the License.
  8102. #include <iostream>
  8103. -#include "absl/flags/flag.h" // from @com_google_absl
  8104. -#include "absl/flags/parse.h" // from @com_google_absl
  8105. -#include "absl/status/status.h" // from @com_google_absl
  8106. +#include "absl/flags/flag.h" // from @com_google_absl
  8107. +#include "absl/flags/parse.h" // from @com_google_absl
  8108. +#include "absl/status/status.h" // from @com_google_absl
  8109. #include "absl/strings/str_format.h" // from @com_google_absl
  8110. #include "tensorflow_lite_support/cc/port/statusor.h"
  8111. #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
  8112. @@ -39,28 +39,40 @@ limitations under the License.
  8113. #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
  8114. #include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
  8115. -ABSL_FLAG(std::string, model_path, "",
  8116. +ABSL_FLAG(std::string,
  8117. + model_path,
  8118. + "",
  8119. "Absolute path to the '.tflite' image embedder model.");
  8120. -ABSL_FLAG(std::string, first_image_path, "",
  8121. +ABSL_FLAG(std::string,
  8122. + first_image_path,
  8123. + "",
  8124. "Absolute path to the first image, whose feature vector will be "
  8125. "extracted and compared to the second image using cosine similarity. "
  8126. "The image must be RGB or RGBA (grayscale is not supported). The "
  8127. "image EXIF orientation flag, if any, is NOT taken into account.");
  8128. -ABSL_FLAG(std::string, second_image_path, "",
  8129. +ABSL_FLAG(std::string,
  8130. + second_image_path,
  8131. + "",
  8132. "Absolute path to the second image, whose feature vector will be "
  8133. "extracted and compared to the first image using cosine similarity. "
  8134. "The image must be RGB or RGBA (grayscale is not supported). The "
  8135. "image EXIF orientation flag, if any, is NOT taken into account.");
  8136. -ABSL_FLAG(bool, l2_normalize, false,
  8137. +ABSL_FLAG(bool,
  8138. + l2_normalize,
  8139. + false,
  8140. "If true, the raw feature vectors returned by the image embedder "
  8141. "will be normalized with L2-norm. Generally only needed if the model "
  8142. "doesn't already contain a L2_NORMALIZATION TFLite Op.");
  8143. ABSL_FLAG(
  8144. - bool, quantize, false,
  8145. + bool,
  8146. + quantize,
  8147. + false,
  8148. "If true, the raw feature vectors returned by the image embedder will "
  8149. "be quantized to 8 bit integers (uniform quantization) via post-processing "
  8150. "before cosine similarity is computed.");
  8151. -ABSL_FLAG(bool, use_coral, false,
  8152. +ABSL_FLAG(bool,
  8153. + use_coral,
  8154. + false,
  8155. "If true, inference will be delegated to a connected Coral Edge TPU "
  8156. "device.");
  8157. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_searcher_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_searcher_demo.cc
  8158. index b661447614bc7..e4074f76dba5b 100644
  8159. --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_searcher_demo.cc
  8160. +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_searcher_demo.cc
  8161. @@ -25,9 +25,9 @@ limitations under the License.
  8162. #include <iostream>
  8163. #include <memory>
  8164. -#include "absl/flags/flag.h" // from @com_google_absl
  8165. -#include "absl/flags/parse.h" // from @com_google_absl
  8166. -#include "absl/status/status.h" // from @com_google_absl
  8167. +#include "absl/flags/flag.h" // from @com_google_absl
  8168. +#include "absl/flags/parse.h" // from @com_google_absl
  8169. +#include "absl/status/status.h" // from @com_google_absl
  8170. #include "absl/strings/str_format.h" // from @com_google_absl
  8171. #include "tensorflow_lite_support/cc/port/status_macros.h"
  8172. #include "tensorflow_lite_support/cc/port/statusor.h"
  8173. @@ -42,23 +42,35 @@ limitations under the License.
  8174. #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
  8175. #include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
  8176. -ABSL_FLAG(std::string, model_path, "",
  8177. +ABSL_FLAG(std::string,
  8178. + model_path,
  8179. + "",
  8180. "Absolute path to the '.tflite' image embedder model.");
  8181. -ABSL_FLAG(std::string, index_path, "",
  8182. +ABSL_FLAG(std::string,
  8183. + index_path,
  8184. + "",
  8185. "Absolute path to the index to search into. Mandatory only if the "
  8186. "index is not attached to the output tensor metadata of the embedder "
  8187. "model as an AssociatedFile with type SCANN_INDEX_FILE.");
  8188. -ABSL_FLAG(std::string, image_path, "",
  8189. +ABSL_FLAG(std::string,
  8190. + image_path,
  8191. + "",
  8192. "Absolute path to the image to search. The image must be RGB or "
  8193. "RGBA (grayscale is not supported). The image EXIF orientation "
  8194. "flag, if any, is NOT taken into account.");
  8195. -ABSL_FLAG(int32, max_results, 5,
  8196. +ABSL_FLAG(int32,
  8197. + max_results,
  8198. + 5,
  8199. "Maximum number of nearest-neighbor results to display.");
  8200. -ABSL_FLAG(bool, l2_normalize, false,
  8201. +ABSL_FLAG(bool,
  8202. + l2_normalize,
  8203. + false,
  8204. "If true, the raw feature vectors returned by the image embedder "
  8205. "will be normalized with L2-norm. Generally only needed if the model "
  8206. "doesn't already contain a L2_NORMALIZATION TFLite Op.");
  8207. -ABSL_FLAG(bool, use_coral, false,
  8208. +ABSL_FLAG(bool,
  8209. + use_coral,
  8210. + false,
  8211. "If true, inference will be delegated to a connected Coral Edge TPU "
  8212. "device.");
  8213. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc
  8214. index 5a566ecbcf921..fdc787288fa06 100644
  8215. --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc
  8216. +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc
  8217. @@ -23,10 +23,10 @@ limitations under the License.
  8218. #include <iostream>
  8219. -#include "absl/flags/flag.h" // from @com_google_absl
  8220. -#include "absl/flags/parse.h" // from @com_google_absl
  8221. -#include "absl/status/status.h" // from @com_google_absl
  8222. -#include "absl/strings/match.h" // from @com_google_absl
  8223. +#include "absl/flags/flag.h" // from @com_google_absl
  8224. +#include "absl/flags/parse.h" // from @com_google_absl
  8225. +#include "absl/status/status.h" // from @com_google_absl
  8226. +#include "absl/strings/match.h" // from @com_google_absl
  8227. #include "absl/strings/str_format.h" // from @com_google_absl
  8228. #include "tensorflow_lite_support/cc/port/statusor.h"
  8229. #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
  8230. @@ -37,16 +37,24 @@ limitations under the License.
  8231. #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
  8232. #include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
  8233. -ABSL_FLAG(std::string, model_path, "",
  8234. +ABSL_FLAG(std::string,
  8235. + model_path,
  8236. + "",
  8237. "Absolute path to the '.tflite' image segmenter model.");
  8238. -ABSL_FLAG(std::string, image_path, "",
  8239. +ABSL_FLAG(std::string,
  8240. + image_path,
  8241. + "",
  8242. "Absolute path to the image to segment. The image must be RGB or "
  8243. "RGBA (grayscale is not supported). The image EXIF orientation "
  8244. "flag, if any, is NOT taken into account.");
  8245. -ABSL_FLAG(std::string, output_mask_png, "",
  8246. +ABSL_FLAG(std::string,
  8247. + output_mask_png,
  8248. + "",
  8249. "Absolute path to the output category mask (confidence masks outputs "
  8250. "are not supported by this tool). Must have a '.png' extension.");
  8251. -ABSL_FLAG(bool, use_coral, false,
  8252. +ABSL_FLAG(bool,
  8253. + use_coral,
  8254. + false,
  8255. "If true, inference will be delegated to a connected Coral Edge TPU "
  8256. "device.");
  8257. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc
  8258. index 20f7403207c2e..fd000fccf2f29 100644
  8259. --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc
  8260. +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc
  8261. @@ -24,10 +24,10 @@ limitations under the License.
  8262. #include <iostream>
  8263. #include <limits>
  8264. -#include "absl/flags/flag.h" // from @com_google_absl
  8265. -#include "absl/flags/parse.h" // from @com_google_absl
  8266. -#include "absl/status/status.h" // from @com_google_absl
  8267. -#include "absl/strings/match.h" // from @com_google_absl
  8268. +#include "absl/flags/flag.h" // from @com_google_absl
  8269. +#include "absl/flags/parse.h" // from @com_google_absl
  8270. +#include "absl/status/status.h" // from @com_google_absl
  8271. +#include "absl/strings/match.h" // from @com_google_absl
  8272. #include "absl/strings/str_format.h" // from @com_google_absl
  8273. #include "tensorflow_lite_support/cc/port/statusor.h"
  8274. #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
  8275. @@ -40,32 +40,48 @@ limitations under the License.
  8276. #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
  8277. #include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
  8278. -ABSL_FLAG(std::string, model_path, "",
  8279. +ABSL_FLAG(std::string,
  8280. + model_path,
  8281. + "",
  8282. "Absolute path to the '.tflite' object detector model.");
  8283. -ABSL_FLAG(std::string, image_path, "",
  8284. +ABSL_FLAG(std::string,
  8285. + image_path,
  8286. + "",
  8287. "Absolute path to the image to run detection on. The image must be "
  8288. "RGB or RGBA (grayscale is not supported). The image EXIF "
  8289. "orientation flag, if any, is NOT taken into account.");
  8290. -ABSL_FLAG(std::string, output_png, "",
  8291. +ABSL_FLAG(std::string,
  8292. + output_png,
  8293. + "",
  8294. "Absolute path to a file where to draw the detection results on top "
  8295. "of the input image. Must have a '.png' extension.");
  8296. -ABSL_FLAG(int32, max_results, 5,
  8297. +ABSL_FLAG(int32,
  8298. + max_results,
  8299. + 5,
  8300. "Maximum number of detection results to display.");
  8301. ABSL_FLAG(
  8302. - float, score_threshold, std::numeric_limits<float>::lowest(),
  8303. + float,
  8304. + score_threshold,
  8305. + std::numeric_limits<float>::lowest(),
  8306. "Detection results with a confidence score below this value are "
  8307. "rejected. If specified, overrides the score threshold(s) provided in the "
  8308. "TFLite Model Metadata. Ignored otherwise.");
  8309. ABSL_FLAG(
  8310. - std::vector<std::string>, class_name_whitelist, {},
  8311. + std::vector<std::string>,
  8312. + class_name_whitelist,
  8313. + {},
  8314. "Comma-separated list of class names that acts as a whitelist. If "
  8315. "non-empty, detections results whose 'class_name' is not in this list "
  8316. "are filtered out. Mutually exclusive with 'class_name_blacklist'.");
  8317. -ABSL_FLAG(std::vector<std::string>, class_name_blacklist, {},
  8318. +ABSL_FLAG(std::vector<std::string>,
  8319. + class_name_blacklist,
  8320. + {},
  8321. "Comma-separated list of class names that acts as a blacklist. If "
  8322. "non-empty, detections results whose 'class_name' is in this list "
  8323. "are filtered out. Mutually exclusive with 'class_name_whitelist'.");
  8324. -ABSL_FLAG(bool, use_coral, false,
  8325. +ABSL_FLAG(bool,
  8326. + use_coral,
  8327. + false,
  8328. "If true, inference will be delegated to a connected Coral Edge TPU "
  8329. "device.");
  8330. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h
  8331. index a4fee55abe158..2ca42fb7f3fbe 100644
  8332. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h
  8333. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h
  8334. @@ -56,7 +56,8 @@ typedef NS_ENUM(NSUInteger, TFLSupportErrorCode) {
  8335. /** TensorFlow Lite metadata error codes. */
  8336. - /** Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer. */
  8337. + /** Unexpected schema version (aka file_identifier) in the Metadata
  8338. + FlatBuffer. */
  8339. TFLSupportErrorCodeMetadataInvalidSchemaVersionError = 200,
  8340. /** No such associated file within metadata, or file has not been packed. */
  8341. @@ -198,11 +199,13 @@ typedef NS_ENUM(NSUInteger, TFLSupportErrorCode) {
  8342. */
  8343. TFLSupportErrorCodeImageProcessingBackendError,
  8344. - /** kNotFound indicates some requested entity (such as a file or directory) was not found. */
  8345. + /** kNotFound indicates some requested entity (such as a file or directory)
  8346. + was not found. */
  8347. TFLSupportErrorCodeNotFoundError = 900,
  8348. - /** kInternal indicates an internal error has occurred and some invariants expected by the
  8349. - * underlying system have not been satisfied. This error code is reserved for serious errors.
  8350. + /** kInternal indicates an internal error has occurred and some invariants
  8351. + * expected by the underlying system have not been satisfied. This error code
  8352. + * is reserved for serious errors.
  8353. */
  8354. TFLSupportErrorCodeInternalError,
  8355. } NS_SWIFT_NAME(SupportErrorCode);
  8356. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h
  8357. index f3d71984a3213..58710c6f8eeeb 100644
  8358. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h
  8359. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h
  8360. @@ -25,36 +25,36 @@ NS_ASSUME_NONNULL_BEGIN
  8361. *
  8362. * @param code Error code.
  8363. * @param description Error description.
  8364. - * @param error Pointer to the memory location where the created error should be saved. If `nil`,
  8365. - * no error will be saved.
  8366. + * @param error Pointer to the memory location where the created error should be
  8367. + * saved. If `nil`, no error will be saved.
  8368. */
  8369. -+ (void)createCustomError:(NSError **)error
  8370. ++ (void)createCustomError:(NSError**)error
  8371. withCode:(NSInteger)code
  8372. - description:(NSString *)description;
  8373. + description:(NSString*)description;
  8374. /**
  8375. * Converts a C library error, TfLiteSupportError to an NSError.
  8376. *
  8377. * @param supportError C library error.
  8378. - * @param error Pointer to the memory location where the created error should be saved. If `nil`,
  8379. - * no error will be saved.
  8380. + * @param error Pointer to the memory location where the created error should be
  8381. + * saved. If `nil`, no error will be saved.
  8382. */
  8383. -+ (BOOL)checkCError:(TfLiteSupportError *)supportError toError:(NSError **)error;
  8384. ++ (BOOL)checkCError:(TfLiteSupportError*)supportError toError:(NSError**)error;
  8385. /**
  8386. - * Allocates a block of memory with the specified size and returns a pointer to it. If memory
  8387. - * cannot be allocated because of an invalid memSize, it saves an error. In other cases, it
  8388. - * terminates program execution.
  8389. + * Allocates a block of memory with the specified size and returns a pointer to
  8390. + * it. If memory cannot be allocated because of an invalid memSize, it saves an
  8391. + * error. In other cases, it terminates program execution.
  8392. *
  8393. * @param memSize size of memory to be allocated
  8394. - * @param error Pointer to the memory location where errors if any should be saved. If `nil`, no
  8395. - * error will be saved.
  8396. + * @param error Pointer to the memory location where errors if any should be
  8397. + * saved. If `nil`, no error will be saved.
  8398. *
  8399. - * @return Pointer to the allocated block of memory on successfull allocation. nil in case as
  8400. - * error is encountered because of invalid memSize. If failure is due to any other reason, method
  8401. - * terminates program execution.
  8402. + * @return Pointer to the allocated block of memory on successfull allocation.
  8403. + * nil in case as error is encountered because of invalid memSize. If failure is
  8404. + * due to any other reason, method terminates program execution.
  8405. */
  8406. -+ (void *)mallocWithSize:(size_t)memSize error:(NSError **)error;
  8407. ++ (void*)mallocWithSize:(size_t)memSize error:(NSError**)error;
  8408. @end
  8409. NS_ASSUME_NONNULL_END
  8410. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m
  8411. index 3904b0ba11d68..9e23b5b571386 100644
  8412. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m
  8413. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m
  8414. @@ -20,23 +20,26 @@ static NSString *const TFLSupportTaskErrorDomain = @"org.tensorflow.lite.tasks";
  8415. @implementation TFLCommonUtils
  8416. -+ (void)createCustomError:(NSError **)error
  8417. ++ (void)createCustomError:(NSError**)error
  8418. withCode:(NSInteger)code
  8419. - description:(NSString *)description {
  8420. + description:(NSString*)description {
  8421. if (error) {
  8422. - *error = [NSError errorWithDomain:TFLSupportTaskErrorDomain
  8423. - code:code
  8424. - userInfo:@{NSLocalizedDescriptionKey : description}];
  8425. + *error =
  8426. + [NSError errorWithDomain:TFLSupportTaskErrorDomain
  8427. + code:code
  8428. + userInfo:@{NSLocalizedDescriptionKey : description}];
  8429. }
  8430. }
  8431. -+ (BOOL)checkCError:(TfLiteSupportError *)supportError toError:(NSError **)error {
  8432. ++ (BOOL)checkCError:(TfLiteSupportError*)supportError toError:(NSError**)error {
  8433. if (!supportError) {
  8434. return YES;
  8435. }
  8436. - NSString *description = [NSString stringWithCString:supportError->message
  8437. + NSString* description = [NSString stringWithCString:supportError->message
  8438. encoding:NSUTF8StringEncoding];
  8439. - [self createCustomError:error withCode:supportError->code description:description];
  8440. + [self createCustomError:error
  8441. + withCode:supportError->code
  8442. + description:description];
  8443. return NO;
  8444. }
  8445. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h
  8446. index 79b6ba238e982..a5db97038a047 100644
  8447. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h
  8448. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h
  8449. @@ -23,26 +23,28 @@ NS_ASSUME_NONNULL_BEGIN
  8450. @property(nonatomic, readonly) NSUInteger size;
  8451. /** Pointer to float array wrapped by `TFLFloatBuffer`. */
  8452. -@property(nonatomic, readonly) float *data;
  8453. +@property(nonatomic, readonly) float* data;
  8454. /**
  8455. - * Initializes a new `TFLFloatBuffer` by copying the elements of the given float data array.
  8456. + * Initializes a new `TFLFloatBuffer` by copying the elements of the given float
  8457. + * data array.
  8458. *
  8459. - * @param data A pointer to a float data array whose values are to be copied into the buffer.
  8460. + * @param data A pointer to a float data array whose values are to be copied
  8461. + * into the buffer.
  8462. * @param size Size of the array float data array.
  8463. *
  8464. - * @return A new instance of `TFLFloatBuffer` initialized with the elements of the given float data
  8465. - * array.
  8466. + * @return A new instance of `TFLFloatBuffer` initialized with the elements of
  8467. + * the given float data array.
  8468. */
  8469. -- (instancetype)initWithData:(float *)data size:(NSUInteger)size;
  8470. +- (instancetype)initWithData:(float*)data size:(NSUInteger)size;
  8471. /**
  8472. * Initializes a `TFLFloatBuffer` of the specified size with zeros.
  8473. *
  8474. * @param size Number of elements the `TFLFloatBuffer` can hold.
  8475. *
  8476. - * @return A new instance of `TFLFloatBuffer` of the given size with all elements initialized to
  8477. - * zero.
  8478. + * @return A new instance of `TFLFloatBuffer` of the given size with all
  8479. + * elements initialized to zero.
  8480. */
  8481. - (instancetype)initWithSize:(NSUInteger)size;
  8482. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.m
  8483. index 24d50affb27aa..d32fc4363efc2 100644
  8484. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.m
  8485. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.m
  8486. @@ -16,7 +16,7 @@
  8487. @implementation TFLFloatBuffer
  8488. -- (instancetype)initWithData:(float *)data size:(NSUInteger)size {
  8489. +- (instancetype)initWithData:(float*)data size:(NSUInteger)size {
  8490. self = [self init];
  8491. if (self) {
  8492. _size = size;
  8493. @@ -43,7 +43,7 @@
  8494. return self;
  8495. }
  8496. -- (id)copyWithZone:(NSZone *)zone {
  8497. +- (id)copyWithZone:(NSZone*)zone {
  8498. return [[TFLFloatBuffer alloc] initWithData:_data size:_size];
  8499. }
  8500. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h
  8501. index 5a0ab68974b88..b300de6b94d89 100644
  8502. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h
  8503. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h
  8504. @@ -17,13 +17,14 @@
  8505. NS_ASSUME_NONNULL_BEGIN
  8506. -/** An wrapper class which stores a buffer that is written in circular fashion. */
  8507. +/** An wrapper class which stores a buffer that is written in circular fashion.
  8508. + */
  8509. @interface TFLRingBuffer : NSObject
  8510. /**
  8511. * A copy of all the internal ring buffer elements in order.
  8512. */
  8513. -@property(nullable, nonatomic, readonly) TFLFloatBuffer *floatBuffer;
  8514. +@property(nullable, nonatomic, readonly) TFLFloatBuffer* floatBuffer;
  8515. /**
  8516. * Capacity of the ring buffer in number of elements.
  8517. @@ -36,34 +37,37 @@ NS_ASSUME_NONNULL_BEGIN
  8518. *
  8519. * @param size Size of the ring buffer.
  8520. *
  8521. - * @return A new instance of `TFLRingBuffer` with the given size and all elements
  8522. - * initialized to zero.
  8523. + * @return A new instance of `TFLRingBuffer` with the given size and all
  8524. + * elements initialized to zero.
  8525. */
  8526. - (instancetype)initWithBufferSize:(NSUInteger)size;
  8527. /**
  8528. - * Loads a slice of a float array to the ring buffer. If the float array is longer than ring
  8529. - * buffer's capacity, samples with lower indices in the array will be ignored.
  8530. + * Loads a slice of a float array to the ring buffer. If the float array is
  8531. + * longer than ring buffer's capacity, samples with lower indices in the array
  8532. + * will be ignored.
  8533. *
  8534. * @return Boolean indicating success or failure of loading operation.
  8535. */
  8536. -- (BOOL)loadBuffer:(TFLFloatBuffer *)sourceBuffer
  8537. +- (BOOL)loadBuffer:(TFLFloatBuffer*)sourceBuffer
  8538. offset:(NSUInteger)offset
  8539. size:(NSUInteger)size
  8540. - error:(NSError **)error;
  8541. + error:(NSError**)error;
  8542. /**
  8543. - * Returns a `TFLFloatBuffer` with a copy of size number of the ring buffer elements in order
  8544. - * starting at offset, i.e, buffer[offset:offset+size].
  8545. + * Returns a `TFLFloatBuffer` with a copy of size number of the ring buffer
  8546. + * elements in order starting at offset, i.e, buffer[offset:offset+size].
  8547. *
  8548. - * @param offset Offset in the ring buffer from which elements are to be returned.
  8549. + * @param offset Offset in the ring buffer from which elements are to be
  8550. + * returned.
  8551. *
  8552. * @param size Number of elements to be returned.
  8553. *
  8554. - * @return A new `TFLFloatBuffer` if offset + size is within the bounds of the ring buffer,
  8555. - * otherwise nil.
  8556. + * @return A new `TFLFloatBuffer` if offset + size is within the bounds of the
  8557. + * ring buffer, otherwise nil.
  8558. */
  8559. -- (nullable TFLFloatBuffer *)floatBufferWithOffset:(NSUInteger)offset size:(NSUInteger)size;
  8560. +- (nullable TFLFloatBuffer*)floatBufferWithOffset:(NSUInteger)offset
  8561. + size:(NSUInteger)size;
  8562. /**
  8563. * Clears the `TFLRingBuffer` by setting all the elements to zero .
  8564. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.m
  8565. index 675f7058fff61..57495409f51c8 100644
  8566. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.m
  8567. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.m
  8568. @@ -18,7 +18,7 @@
  8569. @implementation TFLRingBuffer {
  8570. NSUInteger _nextIndex;
  8571. - TFLFloatBuffer *_buffer;
  8572. + TFLFloatBuffer* _buffer;
  8573. }
  8574. - (instancetype)initWithBufferSize:(NSUInteger)size {
  8575. @@ -29,18 +29,18 @@
  8576. return self;
  8577. }
  8578. -- (BOOL)loadBuffer:(TFLFloatBuffer *)sourceBuffer
  8579. +- (BOOL)loadBuffer:(TFLFloatBuffer*)sourceBuffer
  8580. offset:(NSUInteger)offset
  8581. size:(NSUInteger)size
  8582. - error:(NSError **)error {
  8583. + error:(NSError**)error {
  8584. NSUInteger sizeToCopy = size;
  8585. NSUInteger newOffset = offset;
  8586. if (offset + size > sourceBuffer.size) {
  8587. - [TFLCommonUtils
  8588. - createCustomError:error
  8589. - withCode:TFLSupportErrorCodeInvalidArgumentError
  8590. - description:@"offset + size exceeds the maximum size of the source buffer."];
  8591. + [TFLCommonUtils createCustomError:error
  8592. + withCode:TFLSupportErrorCodeInvalidArgumentError
  8593. + description:@"offset + size exceeds the maximum size "
  8594. + @"of the source buffer."];
  8595. return NO;
  8596. }
  8597. @@ -51,13 +51,15 @@
  8598. newOffset = offset + (size - _buffer.size);
  8599. }
  8600. - // If the new nextIndex + sizeToCopy is smaller than the size of the ring buffer directly
  8601. - // copy all elements to the end of the ring buffer.
  8602. + // If the new nextIndex + sizeToCopy is smaller than the size of the ring
  8603. + // buffer directly copy all elements to the end of the ring buffer.
  8604. if (_nextIndex + sizeToCopy < _buffer.size) {
  8605. - memcpy(_buffer.data + _nextIndex, sourceBuffer.data + newOffset, sizeof(float) * sizeToCopy);
  8606. + memcpy(_buffer.data + _nextIndex, sourceBuffer.data + newOffset,
  8607. + sizeof(float) * sizeToCopy);
  8608. } else {
  8609. NSUInteger endChunkSize = _buffer.size - _nextIndex;
  8610. - memcpy(_buffer.data + _nextIndex, sourceBuffer.data + newOffset, sizeof(float) * endChunkSize);
  8611. + memcpy(_buffer.data + _nextIndex, sourceBuffer.data + newOffset,
  8612. + sizeof(float) * endChunkSize);
  8613. NSUInteger startChunkSize = sizeToCopy - endChunkSize;
  8614. memcpy(_buffer.data, sourceBuffer.data + newOffset + endChunkSize,
  8615. @@ -69,16 +71,17 @@
  8616. return YES;
  8617. }
  8618. -- (TFLFloatBuffer *)floatBuffer {
  8619. +- (TFLFloatBuffer*)floatBuffer {
  8620. return [self floatBufferWithOffset:0 size:self.size];
  8621. }
  8622. -- (nullable TFLFloatBuffer *)floatBufferWithOffset:(NSUInteger)offset size:(NSUInteger)size {
  8623. +- (nullable TFLFloatBuffer*)floatBufferWithOffset:(NSUInteger)offset
  8624. + size:(NSUInteger)size {
  8625. if (offset + size > _buffer.size) {
  8626. return nil;
  8627. }
  8628. - TFLFloatBuffer *bufferToReturn = [[TFLFloatBuffer alloc] initWithSize:size];
  8629. + TFLFloatBuffer* bufferToReturn = [[TFLFloatBuffer alloc] initWithSize:size];
  8630. // Return buffer in correct order.
  8631. // Compute offset in flat ring buffer array considering warping.
  8632. @@ -86,17 +89,21 @@
  8633. // If no; elements to be copied are within the end of the flat ring buffer.
  8634. if ((correctOffset + size) <= _buffer.size) {
  8635. - memcpy(bufferToReturn.data, _buffer.data + correctOffset, sizeof(float) * size);
  8636. + memcpy(bufferToReturn.data, _buffer.data + correctOffset,
  8637. + sizeof(float) * size);
  8638. } else {
  8639. - // If no; elements to be copied warps around to the beginning of the ring buffer.
  8640. - // Copy the chunk starting at ringBuffer[nextIndex + offset : size] to
  8641. - // beginning of the result array.
  8642. + // If no; elements to be copied warps around to the beginning of the ring
  8643. + // buffer. Copy the chunk starting at ringBuffer[nextIndex + offset : size]
  8644. + // to beginning of the result array.
  8645. NSInteger endChunkSize = _buffer.size - correctOffset;
  8646. - memcpy(bufferToReturn.data, _buffer.data + correctOffset, sizeof(float) * endChunkSize);
  8647. + memcpy(bufferToReturn.data, _buffer.data + correctOffset,
  8648. + sizeof(float) * endChunkSize);
  8649. - // Next copy the chunk starting at ringBuffer[0 : size - endChunkSize] to the result array.
  8650. + // Next copy the chunk starting at ringBuffer[0 : size - endChunkSize] to
  8651. + // the result array.
  8652. NSInteger firstChunkSize = size - endChunkSize;
  8653. - memcpy(bufferToReturn.data + endChunkSize, _buffer.data, sizeof(float) * firstChunkSize);
  8654. + memcpy(bufferToReturn.data + endChunkSize, _buffer.data,
  8655. + sizeof(float) * firstChunkSize);
  8656. }
  8657. return bufferToReturn;
  8658. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h
  8659. index a117bd7b3c4c3..5058f7c9a5a7b 100644
  8660. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h
  8661. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h
  8662. @@ -18,7 +18,7 @@
  8663. NS_ASSUME_NONNULL_BEGIN
  8664. @interface TFLBaseOptions (Helpers)
  8665. -- (void)copyToCOptions:(TfLiteBaseOptions *)cBaseOptions;
  8666. +- (void)copyToCOptions:(TfLiteBaseOptions*)cBaseOptions;
  8667. @end
  8668. NS_ASSUME_NONNULL_END
  8669. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h
  8670. index 330132f4ba138..7ab7e7240791e 100644
  8671. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h
  8672. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h
  8673. @@ -19,10 +19,10 @@ NS_ASSUME_NONNULL_BEGIN
  8674. NS_SWIFT_NAME(CpuSettings)
  8675. @interface TFLCpuSettings : NSObject <NSCopying>
  8676. -/** Specifies the number of threads to be used for TFLite ops that support multi-threadingwhen
  8677. - * running inference with CPU.
  8678. - * @discussion This property hould be greater than 0 or equal to -1. Setting it to -1 has the
  8679. - * effect to let TFLite runtime set the value.
  8680. +/** Specifies the number of threads to be used for TFLite ops that support
  8681. + * multi-threadingwhen running inference with CPU.
  8682. + * @discussion This property hould be greater than 0 or equal to -1. Setting it
  8683. + * to -1 has the effect to let TFLite runtime set the value.
  8684. */
  8685. @property(nonatomic) int numThreads;
  8686. @@ -35,7 +35,7 @@ NS_SWIFT_NAME(ComputeSettings)
  8687. @interface TFLComputeSettings : NSObject <NSCopying>
  8688. /** Holds cpu settings. */
  8689. -@property(nonatomic, copy) TFLCpuSettings *cpuSettings;
  8690. +@property(nonatomic, copy) TFLCpuSettings* cpuSettings;
  8691. @end
  8692. @@ -46,30 +46,32 @@ NS_SWIFT_NAME(ExternalFile)
  8693. @interface TFLExternalFile : NSObject <NSCopying>
  8694. /** Path to the file in bundle. */
  8695. -@property(nonatomic, copy) NSString *filePath;
  8696. +@property(nonatomic, copy) NSString* filePath;
  8697. /// Add provision for other sources in future.
  8698. @end
  8699. /**
  8700. - * Holds the base options that is used for creation of any type of task. It has fields with
  8701. - * important information acceleration configuration, tflite model source etc.
  8702. + * Holds the base options that is used for creation of any type of task. It has
  8703. + * fields with important information acceleration configuration, tflite model
  8704. + * source etc.
  8705. */
  8706. NS_SWIFT_NAME(BaseOptions)
  8707. @interface TFLBaseOptions : NSObject <NSCopying>
  8708. /**
  8709. - * The external model file, as a single standalone TFLite file. It could be packed with TFLite Model
  8710. - * Metadata[1] and associated files if exist. Fail to provide the necessary metadata and associated
  8711. - * files might result in errors.
  8712. + * The external model file, as a single standalone TFLite file. It could be
  8713. + * packed with TFLite Model Metadata[1] and associated files if exist. Fail to
  8714. + * provide the necessary metadata and associated files might result in errors.
  8715. */
  8716. -@property(nonatomic, copy) TFLExternalFile *modelFile;
  8717. +@property(nonatomic, copy) TFLExternalFile* modelFile;
  8718. /**
  8719. - * Holds settings for one possible acceleration configuration including.cpu/gpu settings.
  8720. - * Please see documentation of TfLiteComputeSettings and its members for more details.
  8721. + * Holds settings for one possible acceleration configuration including.cpu/gpu
  8722. + * settings. Please see documentation of TfLiteComputeSettings and its members
  8723. + * for more details.
  8724. */
  8725. -@property(nonatomic, copy) TFLComputeSettings *computeSettings;
  8726. +@property(nonatomic, copy) TFLComputeSettings* computeSettings;
  8727. @end
  8728. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.h
  8729. index 617fa3ae7120e..6f515e46744b9 100644
  8730. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.h
  8731. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.h
  8732. @@ -30,7 +30,7 @@ NS_ASSUME_NONNULL_BEGIN
  8733. * results returned by inference methods of the iOS TF Lite Task Classification
  8734. * tasks.
  8735. */
  8736. -+ (TFLCategory *)categoryWithCCategory:(TfLiteCategory *)cCategory;
  8737. ++ (TFLCategory*)categoryWithCCategory:(TfLiteCategory*)cCategory;
  8738. @end
  8739. NS_ASSUME_NONNULL_END
  8740. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.m
  8741. index 7d49c36aa48c9..4139525500a59 100644
  8742. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.m
  8743. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.m
  8744. @@ -19,8 +19,8 @@
  8745. + (TFLCategory *)categoryWithCCategory:(TfLiteCategory *)cCategory {
  8746. if (cCategory == nil) return nil;
  8747. - NSString *displayName;
  8748. - NSString *label;
  8749. + NSString* displayName;
  8750. + NSString* label;
  8751. if (cCategory->display_name != nil) {
  8752. displayName = [NSString stringWithCString:cCategory->display_name
  8753. @@ -28,7 +28,8 @@
  8754. }
  8755. if (cCategory->label != nil) {
  8756. - label = [NSString stringWithCString:cCategory->label encoding:NSUTF8StringEncoding];
  8757. + label = [NSString stringWithCString:cCategory->label
  8758. + encoding:NSUTF8StringEncoding];
  8759. }
  8760. return [[TFLCategory alloc] initWithIndex:(NSInteger)cCategory->index
  8761. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.h
  8762. index 91060ef4f1840..5c521f2239ab7 100644
  8763. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.h
  8764. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.h
  8765. @@ -20,24 +20,25 @@ NS_ASSUME_NONNULL_BEGIN
  8766. NS_SWIFT_NAME(ClassificationCategory)
  8767. @interface TFLCategory : NSObject
  8768. -/** Index of the class in the corresponding label map, usually packed in the TFLite Model
  8769. - * Metadata. */
  8770. +/** Index of the class in the corresponding label map, usually packed in the
  8771. + * TFLite Model Metadata. */
  8772. @property(nonatomic, readonly) NSInteger index;
  8773. /** Confidence score for this class . */
  8774. @property(nonatomic, readonly) float score;
  8775. /** Class name of the class. */
  8776. -@property(nonatomic, readonly, nullable) NSString *label;
  8777. +@property(nonatomic, readonly, nullable) NSString* label;
  8778. /** Display name of the class. */
  8779. -@property(nonatomic, readonly, nullable) NSString *displayName;
  8780. +@property(nonatomic, readonly, nullable) NSString* displayName;
  8781. /**
  8782. - * Initializes a new `TFLCategory` with the given index, score, label and display name.
  8783. + * Initializes a new `TFLCategory` with the given index, score, label and
  8784. + * display name.
  8785. *
  8786. - * @param index Index of the class in the corresponding label map, usually packed in the TFLite
  8787. - * Model Metadata.
  8788. + * @param index Index of the class in the corresponding label map, usually
  8789. + * packed in the TFLite Model Metadata.
  8790. *
  8791. * @param score Confidence score for this class.
  8792. *
  8793. @@ -45,12 +46,13 @@ NS_SWIFT_NAME(ClassificationCategory)
  8794. *
  8795. * @param displayName Display name of the class.
  8796. *
  8797. - * @return An instance of `TFLCategory` initialized with the given index, score, label and display name.
  8798. + * @return An instance of `TFLCategory` initialized with the given index, score,
  8799. + * label and display name.
  8800. */
  8801. - (instancetype)initWithIndex:(NSInteger)index
  8802. score:(float)score
  8803. - label:(nullable NSString *)label
  8804. - displayName:(nullable NSString *)displayName;
  8805. + label:(nullable NSString*)label
  8806. + displayName:(nullable NSString*)displayName;
  8807. - (instancetype)init NS_UNAVAILABLE;
  8808. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.m
  8809. index b72c3b55fdaf1..603c5a27c9673 100644
  8810. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.m
  8811. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.m
  8812. @@ -18,8 +18,8 @@
  8813. - (instancetype)initWithIndex:(NSInteger)index
  8814. score:(float)score
  8815. - label:(nullable NSString *)label
  8816. - displayName:(nullable NSString *)displayName {
  8817. + label:(nullable NSString*)label
  8818. + displayName:(nullable NSString*)displayName {
  8819. self = [super init];
  8820. if (self) {
  8821. _index = index;
  8822. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h
  8823. index b12c118e89021..152aa33dbdb59 100644
  8824. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h
  8825. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h
  8826. @@ -18,11 +18,11 @@
  8827. NS_ASSUME_NONNULL_BEGIN
  8828. @interface TFLClassificationOptions (Helpers)
  8829. -- (BOOL)copyToCOptions:(TfLiteClassificationOptions *)cClassificationOptions
  8830. - error:(NSError **)error;
  8831. +- (BOOL)copyToCOptions:(TfLiteClassificationOptions*)cClassificationOptions
  8832. + error:(NSError**)error;
  8833. - (void)deleteAllocatedMemoryOfClassificationOptions:
  8834. - (TfLiteClassificationOptions *)cClassificationOptions;
  8835. + (TfLiteClassificationOptions*)cClassificationOptions;
  8836. @end
  8837. NS_ASSUME_NONNULL_END
  8838. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m
  8839. index 84e8fa5e234fb..767e5e4d577a3 100644
  8840. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m
  8841. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m
  8842. @@ -20,21 +20,28 @@
  8843. + (char **)cStringArrayFromNSArray:(NSArray<NSString *> *)strings error:(NSError **)error {
  8844. if (strings.count <= 0) {
  8845. - [TFLCommonUtils createCustomError:error
  8846. - withCode:TFLSupportErrorCodeInvalidArgumentError
  8847. - description:@"Invalid length of strings found for list type options."];
  8848. + [TFLCommonUtils
  8849. + createCustomError:error
  8850. + withCode:TFLSupportErrorCodeInvalidArgumentError
  8851. + description:
  8852. + @"Invalid length of strings found for list type options."];
  8853. return nil;
  8854. }
  8855. - char **cStrings = [TFLCommonUtils mallocWithSize:strings.count * sizeof(char *) error:error];
  8856. - if (!cStrings) return NULL;
  8857. + char** cStrings = [TFLCommonUtils mallocWithSize:strings.count * sizeof(char*)
  8858. + error:error];
  8859. + if (!cStrings)
  8860. + return NULL;
  8861. for (NSInteger i = 0; i < strings.count; i++) {
  8862. cStrings[i] = [TFLCommonUtils
  8863. - mallocWithSize:([strings[i] lengthOfBytesUsingEncoding:NSUTF8StringEncoding] + 1) *
  8864. + mallocWithSize:([strings[i]
  8865. + lengthOfBytesUsingEncoding:NSUTF8StringEncoding] +
  8866. + 1) *
  8867. sizeof(char)
  8868. error:error];
  8869. - if (!cStrings[i]) return NULL;
  8870. + if (!cStrings[i])
  8871. + return NULL;
  8872. strcpy(cStrings[i], strings[i].UTF8String);
  8873. }
  8874. @@ -77,14 +84,16 @@
  8875. if (self.displayNamesLocale) {
  8876. if (self.displayNamesLocale.UTF8String) {
  8877. - cClassificationOptions->display_names_local = strdup(self.displayNamesLocale.UTF8String);
  8878. + cClassificationOptions->display_names_local =
  8879. + strdup(self.displayNamesLocale.UTF8String);
  8880. if (!cClassificationOptions->display_names_local) {
  8881. exit(-1); // Memory Allocation Failed.
  8882. }
  8883. } else {
  8884. - [TFLCommonUtils createCustomError:error
  8885. - withCode:TFLSupportErrorCodeInvalidArgumentError
  8886. - description:@"Could not convert (NSString *) to (char *)."];
  8887. + [TFLCommonUtils
  8888. + createCustomError:error
  8889. + withCode:TFLSupportErrorCodeInvalidArgumentError
  8890. + description:@"Could not convert (NSString *) to (char *)."];
  8891. return NO;
  8892. }
  8893. }
  8894. @@ -93,7 +102,7 @@
  8895. }
  8896. - (void)deleteAllocatedMemoryOfClassificationOptions:
  8897. - (TfLiteClassificationOptions *)cClassificationOptions {
  8898. + (TfLiteClassificationOptions*)cClassificationOptions {
  8899. if (self.labelAllowList) {
  8900. [TFLClassificationOptions deleteCStringsArray:cClassificationOptions->label_allowlist.list
  8901. count:cClassificationOptions->label_allowlist.length];
  8902. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h
  8903. index 41b69bec8a7d8..ce3f5d6580913 100644
  8904. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h
  8905. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h
  8906. @@ -23,13 +23,14 @@ NS_SWIFT_NAME(ClassificationOptions)
  8907. @interface TFLClassificationOptions : NSObject <NSCopying>
  8908. /** If set, all classes in this list will be filtered out from the results . */
  8909. -@property(nonatomic, copy) NSArray *labelDenyList;
  8910. +@property(nonatomic, copy) NSArray* labelDenyList;
  8911. -/** If set, all classes not in this list will be filtered out from the results . */
  8912. -@property(nonatomic, copy) NSArray *labelAllowList;
  8913. +/** If set, all classes not in this list will be filtered out from the results .
  8914. + */
  8915. +@property(nonatomic, copy) NSArray* labelAllowList;
  8916. /** Display names local for display names*/
  8917. -@property(nonatomic, copy) NSString *displayNamesLocale;
  8918. +@property(nonatomic, copy) NSString* displayNamesLocale;
  8919. /** Results with score threshold greater than this value are returned . */
  8920. @property(nonatomic) float scoreThreshold;
  8921. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h
  8922. index 7ef58fc5b76ce..351e87db729c6 100644
  8923. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h
  8924. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h
  8925. @@ -20,17 +20,18 @@ NS_ASSUME_NONNULL_BEGIN
  8926. @interface TFLClassificationResult (Helpers)
  8927. /**
  8928. - * Creates and returns a TFLClassificationResult from a TfLiteClassificationResult returned by
  8929. - * TFLite Task C Library Classification tasks.
  8930. + * Creates and returns a TFLClassificationResult from a
  8931. + * TfLiteClassificationResult returned by TFLite Task C Library Classification
  8932. + * tasks.
  8933. *
  8934. - * @param cClassificationResult Classification results returned by TFLite Task C Library
  8935. - * Classification tasks
  8936. + * @param cClassificationResult Classification results returned by TFLite Task C
  8937. + * Library Classification tasks
  8938. *
  8939. - * @return Classification Result of type TFLClassificationResult to be returned by inference methods
  8940. - * of the iOS TF Lite Task Classification tasks.
  8941. + * @return Classification Result of type TFLClassificationResult to be returned
  8942. + * by inference methods of the iOS TF Lite Task Classification tasks.
  8943. */
  8944. -+ (TFLClassificationResult *)classificationResultWithCResult:
  8945. - (TfLiteClassificationResult *)cClassificationResult;
  8946. ++ (TFLClassificationResult*)classificationResultWithCResult:
  8947. + (TfLiteClassificationResult*)cClassificationResult;
  8948. @end
  8949. NS_ASSUME_NONNULL_END
  8950. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.m
  8951. index c8744a3bf99c6..52e92852d88a9 100644
  8952. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.m
  8953. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.m
  8954. @@ -19,30 +19,34 @@
  8955. + (TFLClassificationResult *)classificationResultWithCResult:
  8956. (TfLiteClassificationResult *)cClassificationResult {
  8957. - if (!cClassificationResult) return nil;
  8958. + if (!cClassificationResult)
  8959. + return nil;
  8960. NSMutableArray *classificationHeads = [[NSMutableArray alloc] init];
  8961. for (int i = 0; i < cClassificationResult->size; i++) {
  8962. TfLiteClassifications cClassifications = cClassificationResult->classifications[i];
  8963. - NSMutableArray *categories = [[NSMutableArray alloc] init];
  8964. + NSMutableArray* categories = [[NSMutableArray alloc] init];
  8965. for (int j = 0; j < cClassifications.size; j++) {
  8966. TfLiteCategory cCategory = cClassifications.categories[j];
  8967. [categories addObject:[TFLCategory categoryWithCCategory:&cCategory]];
  8968. }
  8969. - NSString *headName = nil;
  8970. + NSString* headName = nil;
  8971. if (cClassifications.head_name) {
  8972. - headName = [NSString stringWithCString:cClassifications.head_name encoding:NSUTF8StringEncoding];
  8973. + headName = [NSString stringWithCString:cClassifications.head_name
  8974. + encoding:NSUTF8StringEncoding];
  8975. }
  8976. -
  8977. - TFLClassifications *classifications = [[TFLClassifications alloc] initWithHeadIndex:cClassifications.head_index
  8978. - headName:headName
  8979. - categories:categories];
  8980. +
  8981. + TFLClassifications* classifications = [[TFLClassifications alloc]
  8982. + initWithHeadIndex:cClassifications.head_index
  8983. + headName:headName
  8984. + categories:categories];
  8985. [classificationHeads addObject:classifications];
  8986. }
  8987. - return [[TFLClassificationResult alloc] initWithClassifications:classificationHeads];
  8988. + return [[TFLClassificationResult alloc]
  8989. + initWithClassifications:classificationHeads];
  8990. }
  8991. @end
  8992. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h
  8993. index 72d5c85dec0d6..052b4f1daf710 100644
  8994. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h
  8995. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h
  8996. @@ -17,58 +17,66 @@ limitations under the License.
  8997. NS_ASSUME_NONNULL_BEGIN
  8998. -/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */
  8999. +/** Encapsulates list of predicted classes (aka labels) for a given image
  9000. + * classifier head. */
  9001. NS_SWIFT_NAME(Classifications)
  9002. @interface TFLClassifications : NSObject
  9003. /**
  9004. - * The index of the image classifier head these classes refer to. This is useful for multi-head
  9005. - * models.
  9006. + * The index of the image classifier head these classes refer to. This is useful
  9007. + * for multi-head models.
  9008. */
  9009. @property(nonatomic, readonly) NSInteger headIndex;
  9010. /** The name of the classifier head, which is the corresponding tensor metadata
  9011. - * name. See https://github.com/tensorflow/tflite-support/blob/710e323265bfb71fdbdd72b3516e00cff15c0326/tensorflow_lite_support/metadata/metadata_schema.fbs#L545
  9012. - * This will always be NULL for the `TFLClassifications` in the `TFLClassificationResult` returned by the follwing methods of `TFLImageClassifier`.
  9013. + * name. See
  9014. + * https://github.com/tensorflow/tflite-support/blob/710e323265bfb71fdbdd72b3516e00cff15c0326/tensorflow_lite_support/metadata/metadata_schema.fbs#L545
  9015. + * This will always be NULL for the `TFLClassifications` in the
  9016. + * `TFLClassificationResult` returned by the follwing methods of
  9017. + * `TFLImageClassifier`.
  9018. * 1. -[TFLImageClassifier classifyWithGMLImage:error:]
  9019. * 2. -[TFLImageClassifier classifyWithGMLImage:regionOfInterest:error:]
  9020. */
  9021. -@property(nonatomic, readonly) NSString *headName;
  9022. +@property(nonatomic, readonly) NSString* headName;
  9023. -/** The array of predicted classes, usually sorted by descending scores (e.g.from high to low
  9024. - * probability). */
  9025. -@property(nonatomic, readonly) NSArray<TFLCategory *> *categories;
  9026. +/** The array of predicted classes, usually sorted by descending scores
  9027. + * (e.g.from high to low probability). */
  9028. +@property(nonatomic, readonly) NSArray<TFLCategory*>* categories;
  9029. /**
  9030. - * Initializes a new `TFLClassifications` with the given head index and array of categories.
  9031. - * head name is initialized to `nil`.
  9032. + * Initializes a new `TFLClassifications` with the given head index and array of
  9033. + * categories. head name is initialized to `nil`.
  9034. *
  9035. - * @param headIndex The index of the image classifier head these classes refer to.
  9036. + * @param headIndex The index of the image classifier head these classes refer
  9037. + * to.
  9038. * @param categories An array of `TFLCategory` objects encapsulating a list of
  9039. - * predictions usually sorted by descending scores (e.g. from high to low probability).
  9040. + * predictions usually sorted by descending scores (e.g. from high to low
  9041. + * probability).
  9042. *
  9043. - * @return An instance of `TFLClassifications` initialized with the given head index and
  9044. - * array of categories.
  9045. + * @return An instance of `TFLClassifications` initialized with the given head
  9046. + * index and array of categories.
  9047. */
  9048. - (instancetype)initWithHeadIndex:(NSInteger)headIndex
  9049. - categories:(NSArray<TFLCategory *> *)categories;
  9050. -
  9051. + categories:(NSArray<TFLCategory*>*)categories;
  9052. /**
  9053. - * Initializes a new `TFLClassifications` with the given head index, head name and array of categories.
  9054. + * Initializes a new `TFLClassifications` with the given head index, head name
  9055. + * and array of categories.
  9056. *
  9057. - * @param headIndex The index of the image classifier head these classes refer to.
  9058. - * @param headName The name of the classifier head, which is the corresponding tensor metadata
  9059. - * name.
  9060. + * @param headIndex The index of the image classifier head these classes refer
  9061. + * to.
  9062. + * @param headName The name of the classifier head, which is the corresponding
  9063. + * tensor metadata name.
  9064. * @param categories An array of `TFLCategory` objects encapsulating a list of
  9065. - * predictions usually sorted by descending scores (e.g. from high to low probability).
  9066. + * predictions usually sorted by descending scores (e.g. from high to low
  9067. + * probability).
  9068. *
  9069. - * @return An object of `TFLClassifications` initialized with the given head index, head name and
  9070. - * array of categories.
  9071. + * @return An object of `TFLClassifications` initialized with the given head
  9072. + * index, head name and array of categories.
  9073. */
  9074. - (instancetype)initWithHeadIndex:(NSInteger)headIndex
  9075. - headName:(nullable NSString *)headName
  9076. - categories:(NSArray<TFLCategory *> *)categories;
  9077. + headName:(nullable NSString*)headName
  9078. + categories:(NSArray<TFLCategory*>*)categories;
  9079. @end
  9080. @@ -76,20 +84,23 @@ NS_SWIFT_NAME(Classifications)
  9081. NS_SWIFT_NAME(ClassificationResult)
  9082. @interface TFLClassificationResult : NSObject
  9083. -/** Array of TFLClassifications objects containing image classifier predictions per image classifier
  9084. - * head.
  9085. +/** Array of TFLClassifications objects containing image classifier predictions
  9086. + * per image classifier head.
  9087. */
  9088. -@property(nonatomic, readonly) NSArray<TFLClassifications *> *classifications;
  9089. +@property(nonatomic, readonly) NSArray<TFLClassifications*>* classifications;
  9090. /**
  9091. - * Initializes a new `TFLClassificationResult` with the given array of classifications.
  9092. + * Initializes a new `TFLClassificationResult` with the given array of
  9093. + * classifications.
  9094. *
  9095. - * @param classifications An Aaray of `TFLClassifications` objects containing image classifier
  9096. - * predictions per image classifier head.
  9097. + * @param classifications An Aaray of `TFLClassifications` objects containing
  9098. + * image classifier predictions per image classifier head.
  9099. *
  9100. - * @return An instance of 1TFLClassificationResult1 initialized with the given array of classifications.
  9101. + * @return An instance of 1TFLClassificationResult1 initialized with the given
  9102. + * array of classifications.
  9103. */
  9104. -- (instancetype)initWithClassifications:(NSArray<TFLClassifications *> *)classifications;
  9105. +- (instancetype)initWithClassifications:
  9106. + (NSArray<TFLClassifications*>*)classifications;
  9107. @end
  9108. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m
  9109. index f56600cb94f3b..0ea238417c891 100644
  9110. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m
  9111. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m
  9112. @@ -17,9 +17,8 @@ limitations under the License.
  9113. @implementation TFLClassifications
  9114. - (instancetype)initWithHeadIndex:(NSInteger)headIndex
  9115. - headName:(nullable NSString *)headName
  9116. - categories:(NSArray<TFLCategory *> *)categories {
  9117. -
  9118. + headName:(nullable NSString*)headName
  9119. + categories:(NSArray<TFLCategory*>*)categories {
  9120. self = [super init];
  9121. if (self) {
  9122. _headIndex = headIndex;
  9123. @@ -30,17 +29,18 @@ limitations under the License.
  9124. }
  9125. - (instancetype)initWithHeadIndex:(NSInteger)headIndex
  9126. - categories:(NSArray<TFLCategory *> *)categories {
  9127. + categories:(NSArray<TFLCategory*>*)categories {
  9128. return [self initWithHeadIndex:headIndex headName:nil categories:categories];
  9129. }
  9130. @end
  9131. @implementation TFLClassificationResult {
  9132. - NSArray<TFLClassifications *> *_classifications;
  9133. + NSArray<TFLClassifications*>* _classifications;
  9134. }
  9135. -- (instancetype)initWithClassifications:(NSArray<TFLClassifications *> *)classifications {
  9136. +- (instancetype)initWithClassifications:
  9137. + (NSArray<TFLClassifications*>*)classifications {
  9138. self = [super init];
  9139. if (self) {
  9140. _classifications = classifications;
  9141. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.h
  9142. index 7f6e8cae27f2c..81efbcc1d8c57 100644
  9143. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.h
  9144. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.h
  9145. @@ -19,16 +19,17 @@ NS_ASSUME_NONNULL_BEGIN
  9146. @interface TFLDetectionResult (Helpers)
  9147. /**
  9148. - * Creates and retrurns a TFLDetectionResult from a TfLiteDetectionResult returned by
  9149. - * TFLite Task C Library Object Detection task.
  9150. + * Creates and retrurns a TFLDetectionResult from a TfLiteDetectionResult
  9151. + * returned by TFLite Task C Library Object Detection task.
  9152. *
  9153. * @param cDetectionResult Detection results returned by TFLite Task C Library
  9154. * Object Detection task.
  9155. *
  9156. - * @return Detection Result of type TFLDetectionResult to be returned by inference methods
  9157. - * of the iOS TF Lite Task Object Detection task.
  9158. + * @return Detection Result of type TFLDetectionResult to be returned by
  9159. + * inference methods of the iOS TF Lite Task Object Detection task.
  9160. */
  9161. -+ (TFLDetectionResult *)detectionResultWithCResult:(TfLiteDetectionResult *)cDetectionResult;
  9162. ++ (TFLDetectionResult*)detectionResultWithCResult:
  9163. + (TfLiteDetectionResult*)cDetectionResult;
  9164. @end
  9165. NS_ASSUME_NONNULL_END
  9166. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.m
  9167. index 405bddf117cdd..3ae292cb0ef3b 100644
  9168. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.m
  9169. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.m
  9170. @@ -17,8 +17,10 @@
  9171. @implementation TFLDetectionResult (Helpers)
  9172. -+ (TFLDetectionResult *)detectionResultWithCResult:(TfLiteDetectionResult *)cDetectionResult {
  9173. - if (!cDetectionResult) return nil;
  9174. ++ (TFLDetectionResult*)detectionResultWithCResult:
  9175. + (TfLiteDetectionResult*)cDetectionResult {
  9176. + if (!cDetectionResult)
  9177. + return nil;
  9178. NSMutableArray *detections = [[NSMutableArray alloc] init];
  9179. for (int i = 0; i < cDetectionResult->size; i++) {
  9180. @@ -30,10 +32,11 @@
  9181. TFLCategory *resultCategory = [TFLCategory categoryWithCCategory:&cCategory];
  9182. [categories addObject:resultCategory];
  9183. }
  9184. - TFLDetection *detection = [[TFLDetection alloc]
  9185. - initWithBoundingBox:CGRectMake(
  9186. - cDetection.bounding_box.origin_x, cDetection.bounding_box.origin_y,
  9187. - cDetection.bounding_box.width, cDetection.bounding_box.height)
  9188. + TFLDetection* detection = [[TFLDetection alloc]
  9189. + initWithBoundingBox:CGRectMake(cDetection.bounding_box.origin_x,
  9190. + cDetection.bounding_box.origin_y,
  9191. + cDetection.bounding_box.width,
  9192. + cDetection.bounding_box.height)
  9193. categories:categories];
  9194. [detections addObject:detection];
  9195. }
  9196. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.h
  9197. index 0c64aa98b6089..00cc75bbc161e 100644
  9198. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.h
  9199. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.h
  9200. @@ -19,31 +19,35 @@ limitations under the License.
  9201. NS_ASSUME_NONNULL_BEGIN
  9202. -/** Encapsulates list of predicted classes (aka labels) and bounding box for a detected object. */
  9203. +/** Encapsulates list of predicted classes (aka labels) and bounding box for a
  9204. + * detected object. */
  9205. NS_SWIFT_NAME(Detection)
  9206. @interface TFLDetection : NSObject
  9207. /**
  9208. - * The index of the image classifier head these classes refer to. This is useful for multi-head
  9209. - * models.
  9210. + * The index of the image classifier head these classes refer to. This is useful
  9211. + * for multi-head models.
  9212. */
  9213. @property(nonatomic, readonly) CGRect boundingBox;
  9214. -/** The array of predicted classes, usually sorted by descending scores (e.g.from high to low
  9215. - * probability). */
  9216. -@property(nonatomic, readonly) NSArray<TFLCategory *> *categories;
  9217. +/** The array of predicted classes, usually sorted by descending scores
  9218. + * (e.g.from high to low probability). */
  9219. +@property(nonatomic, readonly) NSArray<TFLCategory*>* categories;
  9220. /**
  9221. - * Initializes an object of `TFLDetection` with the given bounding box and array of categories.
  9222. + * Initializes an object of `TFLDetection` with the given bounding box and array
  9223. + * of categories.
  9224. *
  9225. - * @param boundingBox CGRect specifying the bounds of the object represented by this detection.
  9226. - * @param categories Array of predicted classes, usually sorted by descending scores (e.g.from high
  9227. - * to low probability).
  9228. + * @param boundingBox CGRect specifying the bounds of the object represented by
  9229. + * this detection.
  9230. + * @param categories Array of predicted classes, usually sorted by descending
  9231. + * scores (e.g.from high to low probability).
  9232. *
  9233. - * @return An instance of `TFLDetection` initialized with the given bounding box and array of categories.
  9234. + * @return An instance of `TFLDetection` initialized with the given bounding box
  9235. + * and array of categories.
  9236. */
  9237. - (instancetype)initWithBoundingBox:(CGRect)boundingBox
  9238. - categories:(NSArray<TFLCategory *> *)categories;
  9239. + categories:(NSArray<TFLCategory*>*)categories;
  9240. - (instancetype)init NS_UNAVAILABLE;
  9241. @@ -55,16 +59,17 @@ NS_SWIFT_NAME(Detection)
  9242. NS_SWIFT_NAME(DetectionResult)
  9243. @interface TFLDetectionResult : NSObject
  9244. -@property(nonatomic, readonly) NSArray<TFLDetection *> *detections;
  9245. +@property(nonatomic, readonly) NSArray<TFLDetection*>* detections;
  9246. /**
  9247. * Initializes a new `TFLDetectionResult` with the given array of detections.
  9248. *
  9249. * @param detections Array of detected objects of type TFLDetection.
  9250. *
  9251. - * @return An instance of `TFLDetectionResult` initialized with the given array of detections.
  9252. + * @return An instance of `TFLDetectionResult` initialized with the given array
  9253. + * of detections.
  9254. */
  9255. -- (instancetype)initWithDetections:(NSArray<TFLDetection *> *)detections;
  9256. +- (instancetype)initWithDetections:(NSArray<TFLDetection*>*)detections;
  9257. - (instancetype)init NS_UNAVAILABLE;
  9258. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.m
  9259. index 280767e6a353a..14cec3bca3d08 100644
  9260. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.m
  9261. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.m
  9262. @@ -17,7 +17,7 @@ limitations under the License.
  9263. @implementation TFLDetection
  9264. - (instancetype)initWithBoundingBox:(CGRect)boundingBox
  9265. - categories:(NSArray<TFLCategory *> *)categories {
  9266. + categories:(NSArray<TFLCategory*>*)categories {
  9267. self = [super init];
  9268. if (self) {
  9269. _boundingBox = boundingBox;
  9270. @@ -30,7 +30,7 @@ limitations under the License.
  9271. @implementation TFLDetectionResult
  9272. -- (instancetype)initWithDetections:(NSArray<TFLDetection *> *)detections {
  9273. +- (instancetype)initWithDetections:(NSArray<TFLDetection*>*)detections {
  9274. self = [super init];
  9275. if (self) {
  9276. _detections = detections;
  9277. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.h
  9278. index c979fda53c70b..0a85efe2877bb 100644
  9279. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.h
  9280. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.h
  9281. @@ -28,8 +28,8 @@ NS_ASSUME_NONNULL_BEGIN
  9282. * @return Segmentation Result of type TFLSegmentationResult to be returned by
  9283. * inference methods of the iOS TF Lite Task Image Segmentation task.
  9284. */
  9285. -+ (TFLSegmentationResult *)segmentationResultWithCResult:
  9286. - (TfLiteSegmentationResult *)cSegmentationResult;
  9287. ++ (TFLSegmentationResult*)segmentationResultWithCResult:
  9288. + (TfLiteSegmentationResult*)cSegmentationResult;
  9289. @end
  9290. NS_ASSUME_NONNULL_END
  9291. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.m
  9292. index f2ea957ca3010..2a897f0ba3614 100644
  9293. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.m
  9294. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.m
  9295. @@ -16,29 +16,31 @@
  9296. @implementation TFLSegmentationResult (Helpers)
  9297. -+ (TFLSegmentationResult *)segmentationResultWithCResult:
  9298. - (TfLiteSegmentationResult *)cSegmentationResult {
  9299. - if (!cSegmentationResult) return nil;
  9300. ++ (TFLSegmentationResult*)segmentationResultWithCResult:
  9301. + (TfLiteSegmentationResult*)cSegmentationResult {
  9302. + if (!cSegmentationResult)
  9303. + return nil;
  9304. - NSMutableArray *segmentations = [[NSMutableArray alloc] init];
  9305. + NSMutableArray* segmentations = [[NSMutableArray alloc] init];
  9306. for (int i = 0; i < cSegmentationResult->size; i++) {
  9307. TfLiteSegmentation cSegmentation = cSegmentationResult->segmentations[i];
  9308. - NSMutableArray *coloredLabels = [[NSMutableArray alloc] init];
  9309. + NSMutableArray* coloredLabels = [[NSMutableArray alloc] init];
  9310. for (int j = 0; j < cSegmentation.colored_labels_size; j++) {
  9311. TfLiteColoredLabel cColoredLabel = cSegmentation.colored_labels[j];
  9312. - NSString *displayName;
  9313. + NSString* displayName;
  9314. if (cColoredLabel.display_name) {
  9315. displayName = [NSString stringWithCString:cColoredLabel.display_name
  9316. encoding:NSUTF8StringEncoding];
  9317. }
  9318. - NSString *label;
  9319. + NSString* label;
  9320. if (cColoredLabel.label) {
  9321. - label = [NSString stringWithCString:cColoredLabel.label encoding:NSUTF8StringEncoding];
  9322. + label = [NSString stringWithCString:cColoredLabel.label
  9323. + encoding:NSUTF8StringEncoding];
  9324. }
  9325. - TFLColoredLabel *coloredLabel =
  9326. + TFLColoredLabel* coloredLabel =
  9327. [[TFLColoredLabel alloc] initWithRed:(NSUInteger)cColoredLabel.r
  9328. green:(NSUInteger)cColoredLabel.g
  9329. blue:(NSUInteger)cColoredLabel.b
  9330. @@ -47,27 +49,29 @@
  9331. [coloredLabels addObject:coloredLabel];
  9332. }
  9333. - TFLSegmentation *segmentation;
  9334. + TFLSegmentation* segmentation;
  9335. if (cSegmentation.confidence_masks) {
  9336. - NSMutableArray *confidenceMasks = [[NSMutableArray alloc] init];
  9337. + NSMutableArray* confidenceMasks = [[NSMutableArray alloc] init];
  9338. for (int i = 0; i < cSegmentation.colored_labels_size; i++) {
  9339. - TFLConfidenceMask *confidenceMask =
  9340. - [[TFLConfidenceMask alloc] initWithWidth:(NSInteger)cSegmentation.width
  9341. - height:(NSInteger)cSegmentation.height
  9342. - mask:cSegmentation.confidence_masks[i]];
  9343. + TFLConfidenceMask* confidenceMask = [[TFLConfidenceMask alloc]
  9344. + initWithWidth:(NSInteger)cSegmentation.width
  9345. + height:(NSInteger)cSegmentation.height
  9346. + mask:cSegmentation.confidence_masks[i]];
  9347. [confidenceMasks addObject:confidenceMask];
  9348. }
  9349. - segmentation = [[TFLSegmentation alloc] initWithConfidenceMasks:confidenceMasks
  9350. - coloredLabels:coloredLabels];
  9351. + segmentation =
  9352. + [[TFLSegmentation alloc] initWithConfidenceMasks:confidenceMasks
  9353. + coloredLabels:coloredLabels];
  9354. } else if (cSegmentation.category_mask) {
  9355. - TFLCategoryMask *categoryMask =
  9356. + TFLCategoryMask* categoryMask =
  9357. [[TFLCategoryMask alloc] initWithWidth:(NSInteger)cSegmentation.width
  9358. height:(NSInteger)cSegmentation.height
  9359. mask:cSegmentation.category_mask];
  9360. - segmentation = [[TFLSegmentation alloc] initWithCategoryMask:categoryMask
  9361. - coloredLabels:coloredLabels];
  9362. + segmentation =
  9363. + [[TFLSegmentation alloc] initWithCategoryMask:categoryMask
  9364. + coloredLabels:coloredLabels];
  9365. }
  9366. [segmentations addObject:segmentation];
  9367. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h
  9368. index 1307e26294dd4..3aca4567ebe2e 100644
  9369. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h
  9370. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h
  9371. @@ -23,7 +23,7 @@ NS_SWIFT_NAME(ConfidenceMask)
  9372. /**
  9373. * Confidence masks of size `width` x `height` for any one class.
  9374. */
  9375. -@property(nonatomic, readonly) float *mask;
  9376. +@property(nonatomic, readonly) float* mask;
  9377. /**
  9378. * The width of the mask. This is an intrinsic parameter of the model being
  9379. @@ -42,7 +42,7 @@ NS_SWIFT_NAME(ConfidenceMask)
  9380. */
  9381. - (instancetype)initWithWidth:(NSInteger)width
  9382. height:(NSInteger)height
  9383. - mask:(float * _Nullable)mask;
  9384. + mask:(float* _Nullable)mask;
  9385. - (instancetype)init NS_UNAVAILABLE;
  9386. @@ -59,7 +59,7 @@ NS_SWIFT_NAME(CategoryMask)
  9387. * The value of each pixel in this mask represents the class to which the
  9388. * pixel belongs.
  9389. */
  9390. -@property(nonatomic, readonly) UInt8 *mask;
  9391. +@property(nonatomic, readonly) UInt8* mask;
  9392. /**
  9393. * The width of the mask. This is an intrinsic parameter of the model being
  9394. @@ -80,15 +80,15 @@ NS_SWIFT_NAME(CategoryMask)
  9395. *
  9396. * @param width Width of the mask.
  9397. * @param height Height of the mask.
  9398. - * @param mask Flattened 2D-array of size `width` x `height`, in row major order.
  9399. - * The value of each pixel in this mask represents the class to which the
  9400. + * @param mask Flattened 2D-array of size `width` x `height`, in row major
  9401. + * order. The value of each pixel in this mask represents the class to which the
  9402. * pixel belongs.
  9403. *
  9404. * @return An instance of TFLCategoryMask initialized to the specified values.
  9405. */
  9406. - (instancetype)initWithWidth:(NSInteger)width
  9407. height:(NSInteger)height
  9408. - mask:(UInt8 * _Nullable)mask;
  9409. + mask:(UInt8* _Nullable)mask;
  9410. - (instancetype)init NS_UNAVAILABLE;
  9411. @@ -107,17 +107,18 @@ NS_SWIFT_NAME(ColoredLabel)
  9412. * The class name, as provided in the label map packed in the TFLite Model
  9413. * Metadata.
  9414. */
  9415. -@property(nonatomic, readonly) NSString *label;
  9416. +@property(nonatomic, readonly) NSString* label;
  9417. /**
  9418. * The display name, as provided in the label map (if available) packed in
  9419. * the TFLite Model Metadata. See displayNamesLocale in
  9420. * TFLClassificationOptions.
  9421. */
  9422. -@property(nonatomic, readonly) NSString *displayName;
  9423. +@property(nonatomic, readonly) NSString* displayName;
  9424. /**
  9425. - * Initializes a new `TFLColoredLabel` with red, gree, blue color components, label and display name.
  9426. + * Initializes a new `TFLColoredLabel` with red, gree, blue color components,
  9427. + * label and display name.
  9428. *
  9429. * @param r Red component of the RGB color components.
  9430. * @param g Green component of the RGB color components.
  9431. @@ -125,13 +126,14 @@ NS_SWIFT_NAME(ColoredLabel)
  9432. * @param label Class name.
  9433. * @param displayName Display name.
  9434. *
  9435. - * @return An instance of TFLColoredLabel initialized with red, gree, blue color components, label and display name.
  9436. + * @return An instance of TFLColoredLabel initialized with red, gree, blue color
  9437. + * components, label and display name.
  9438. */
  9439. - (instancetype)initWithRed:(NSUInteger)r
  9440. green:(NSUInteger)g
  9441. blue:(NSUInteger)b
  9442. - label:(NSString *)label
  9443. - displayName:(NSString *)displayName;
  9444. + label:(NSString*)label
  9445. + displayName:(NSString*)displayName;
  9446. - (instancetype)init NS_UNAVAILABLE;
  9447. @@ -150,7 +152,8 @@ NS_SWIFT_NAME(Segmentation)
  9448. * this particular class.
  9449. * This property is mutually exclusive with `categoryMask`.
  9450. */
  9451. -@property(nonatomic, nullable, readonly) NSArray<TFLConfidenceMask *> *confidenceMasks;
  9452. +@property(nonatomic, nullable, readonly)
  9453. + NSArray<TFLConfidenceMask*>* confidenceMasks;
  9454. /**
  9455. * Holds the category mask.
  9456. @@ -158,7 +161,7 @@ NS_SWIFT_NAME(Segmentation)
  9457. * pixel belongs.
  9458. * This property is mutually exclusive with `confidenceMasks`.
  9459. */
  9460. -@property(nonatomic, nullable, readonly) TFLCategoryMask *categoryMask;
  9461. +@property(nonatomic, nullable, readonly) TFLCategoryMask* categoryMask;
  9462. /**
  9463. * The list of colored labels for all the supported categories (classes).
  9464. @@ -167,33 +170,38 @@ NS_SWIFT_NAME(Segmentation)
  9465. * `colored_labels[i]`, `confidence_masks` indices, i.e. `confidence_masks[i]`
  9466. * is associated with `colored_labels[i]`.
  9467. */
  9468. -@property(nonatomic, readonly) NSArray<TFLColoredLabel *> *coloredLabels;
  9469. +@property(nonatomic, readonly) NSArray<TFLColoredLabel*>* coloredLabels;
  9470. + (instancetype)new NS_UNAVAILABLE;
  9471. /**
  9472. - * Initializes a new `TFLSegmentation` with an array of confidence masks and an array of colored labels.
  9473. - * `categoryMask` is initialized to `nil` as it is mutually exclusive with `confidenceMasks`.
  9474. + * Initializes a new `TFLSegmentation` with an array of confidence masks and an
  9475. + * array of colored labels. `categoryMask` is initialized to `nil` as it is
  9476. + * mutually exclusive with `confidenceMasks`.
  9477. *
  9478. * @param confidenceMasks An array of `TFLConfidenceMask` objects.
  9479. * @param coloredLabels An array of `TFLColoredLabel` objects.
  9480. *
  9481. - * @return An instance of `TFLSegmentation` initialized with an array of confidence masks and an array of colored labels.
  9482. + * @return An instance of `TFLSegmentation` initialized with an array of
  9483. + * confidence masks and an array of colored labels.
  9484. */
  9485. -- (instancetype)initWithConfidenceMasks:(NSArray<TFLConfidenceMask *> *)confidenceMasks
  9486. - coloredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels;
  9487. +- (instancetype)
  9488. + initWithConfidenceMasks:(NSArray<TFLConfidenceMask*>*)confidenceMasks
  9489. + coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels;
  9490. /**
  9491. - * Initializes a new `TFLSegmentation` with a category mask and array of colored labels.
  9492. - * `confidenceMasks` is initialized to `nil` as it is mutually exclusive with `categoryMask`.
  9493. + * Initializes a new `TFLSegmentation` with a category mask and array of colored
  9494. + * labels. `confidenceMasks` is initialized to `nil` as it is mutually exclusive
  9495. + * with `categoryMask`.
  9496. *
  9497. * @param categoryMask A `TFLCategoryMask` object.
  9498. * @param coloredLabels An array of `TFLColoredLabel` objects.
  9499. *
  9500. - * @return An instance of `TFLSegmentation` initialized with a category mask and array of colored labels.
  9501. + * @return An instance of `TFLSegmentation` initialized with a category mask and
  9502. + * array of colored labels.
  9503. */
  9504. -- (instancetype)initWithCategoryMask:(TFLCategoryMask *)categoryMask
  9505. - coloredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels;
  9506. +- (instancetype)initWithCategoryMask:(TFLCategoryMask*)categoryMask
  9507. + coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels;
  9508. - (instancetype)init NS_UNAVAILABLE;
  9509. @@ -209,7 +217,7 @@ NS_SWIFT_NAME(SegmentationResult)
  9510. * e.g. instance segmentation models, which may return one segmentation per
  9511. * object.
  9512. */
  9513. -@property(nonatomic, readonly) NSArray<TFLSegmentation *> *segmentations;
  9514. +@property(nonatomic, readonly) NSArray<TFLSegmentation*>* segmentations;
  9515. + (instancetype)new NS_UNAVAILABLE;
  9516. @@ -218,9 +226,10 @@ NS_SWIFT_NAME(SegmentationResult)
  9517. *
  9518. * @param segmentations An array of `TFLSegmentation` objects.
  9519. *
  9520. - * @return An instance of `TFLSegmentationResult` initialized with an array of segmentations.
  9521. + * @return An instance of `TFLSegmentationResult` initialized with an array of
  9522. + * segmentations.
  9523. */
  9524. -- (instancetype)initWithSegmentations:(NSArray<TFLSegmentation *> *)segmentations;
  9525. +- (instancetype)initWithSegmentations:(NSArray<TFLSegmentation*>*)segmentations;
  9526. - (instancetype)init NS_UNAVAILABLE;
  9527. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.m
  9528. index 33defd1139509..45b5510525fdc 100644
  9529. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.m
  9530. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.m
  9531. @@ -17,13 +17,16 @@
  9532. @implementation TFLCategoryMask
  9533. -- (instancetype)initWithWidth:(NSInteger)width height:(NSInteger)height mask:(UInt8 *)mask {
  9534. +- (instancetype)initWithWidth:(NSInteger)width
  9535. + height:(NSInteger)height
  9536. + mask:(UInt8*)mask {
  9537. self = [super init];
  9538. if (self) {
  9539. _width = width;
  9540. _height = height;
  9541. if (mask != NULL) {
  9542. - _mask = [TFLCommonUtils mallocWithSize:width * height * sizeof(UInt8) error:nil];
  9543. + _mask = [TFLCommonUtils mallocWithSize:width * height * sizeof(UInt8)
  9544. + error:nil];
  9545. if (_mask) {
  9546. memcpy(_mask, mask, width * height * sizeof(UInt8));
  9547. }
  9548. @@ -32,7 +35,7 @@
  9549. return self;
  9550. }
  9551. -- (id)copyWithZone:(NSZone *)zone {
  9552. +- (id)copyWithZone:(NSZone*)zone {
  9553. return [[TFLCategoryMask alloc] initWithWidth:self.width
  9554. height:self.height
  9555. mask:self.mask];
  9556. @@ -46,13 +49,16 @@
  9557. @implementation TFLConfidenceMask
  9558. -- (instancetype)initWithWidth:(NSInteger)width height:(NSInteger)height mask:(float *)mask {
  9559. +- (instancetype)initWithWidth:(NSInteger)width
  9560. + height:(NSInteger)height
  9561. + mask:(float*)mask {
  9562. self = [super init];
  9563. if (self) {
  9564. _width = width;
  9565. _height = height;
  9566. if (mask != NULL) {
  9567. - _mask = [TFLCommonUtils mallocWithSize:width * height * sizeof(float) error:nil];
  9568. + _mask = [TFLCommonUtils mallocWithSize:width * height * sizeof(float)
  9569. + error:nil];
  9570. if (_mask) {
  9571. memcpy(_mask, mask, width * height * sizeof(float));
  9572. }
  9573. @@ -61,7 +67,7 @@
  9574. return self;
  9575. }
  9576. -- (id)copyWithZone:(NSZone *)zone {
  9577. +- (id)copyWithZone:(NSZone*)zone {
  9578. return [[TFLConfidenceMask alloc] initWithWidth:self.width
  9579. height:self.height
  9580. mask:self.mask];
  9581. @@ -78,8 +84,8 @@
  9582. - (instancetype)initWithRed:(NSUInteger)r
  9583. green:(NSUInteger)g
  9584. blue:(NSUInteger)b
  9585. - label:(NSString *)label
  9586. - displayName:(NSString *)displayName {
  9587. + label:(NSString*)label
  9588. + displayName:(NSString*)displayName {
  9589. self = [super init];
  9590. if (self) {
  9591. _r = r;
  9592. @@ -95,21 +101,25 @@
  9593. @implementation TFLSegmentation
  9594. -- (instancetype)initWithConfidenceMasks:(NSArray<TFLConfidenceMask *> *)confidenceMasks
  9595. - coloredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels {
  9596. +- (instancetype)
  9597. + initWithConfidenceMasks:(NSArray<TFLConfidenceMask*>*)confidenceMasks
  9598. + coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels {
  9599. return [self initWithConfidenceMasks:confidenceMasks
  9600. categoryMask:nil
  9601. coloredLabels:coloredLabels];
  9602. }
  9603. -- (instancetype)initWithCategoryMask:(TFLCategoryMask *)categoryMask
  9604. - coloredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels {
  9605. - return [self initWithConfidenceMasks:nil categoryMask:categoryMask coloredLabels:coloredLabels];
  9606. +- (instancetype)initWithCategoryMask:(TFLCategoryMask*)categoryMask
  9607. + coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels {
  9608. + return [self initWithConfidenceMasks:nil
  9609. + categoryMask:categoryMask
  9610. + coloredLabels:coloredLabels];
  9611. }
  9612. -- (instancetype)initWithConfidenceMasks:(NSArray<TFLConfidenceMask *> *)confidenceMasks
  9613. - categoryMask:(TFLCategoryMask *)categoryMask
  9614. - coloredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels {
  9615. +- (instancetype)
  9616. + initWithConfidenceMasks:(NSArray<TFLConfidenceMask*>*)confidenceMasks
  9617. + categoryMask:(TFLCategoryMask*)categoryMask
  9618. + coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels {
  9619. self = [super init];
  9620. if (self) {
  9621. _confidenceMasks = confidenceMasks;
  9622. @@ -123,7 +133,8 @@
  9623. @implementation TFLSegmentationResult
  9624. -- (instancetype)initWithSegmentations:(NSArray<TFLSegmentation *> *)segmentations {
  9625. +- (instancetype)initWithSegmentations:
  9626. + (NSArray<TFLSegmentation*>*)segmentations {
  9627. self = [super init];
  9628. if (self) {
  9629. _segmentations = segmentations;
  9630. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h
  9631. index 99de5ad04febf..ac81a15ac11c6 100644
  9632. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h
  9633. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h
  9634. @@ -27,15 +27,17 @@ NS_ASSUME_NONNULL_BEGIN
  9635. @end
  9636. /**
  9637. - * Classifier API for NLClassification tasks with Bert models, categorizes string into different
  9638. - * classes. The API expects a Bert based TFLite model with metadata populated.
  9639. + * Classifier API for NLClassification tasks with Bert models, categorizes
  9640. + * string into different classes. The API expects a Bert based TFLite model with
  9641. + * metadata populated.
  9642. *
  9643. * The metadata should contain the following information:
  9644. * 1 input_process_unit for Wordpiece/Sentencepiece Tokenizer.
  9645. * 3 input tensors with names "ids", "mask" and "segment_ids".
  9646. - * 1 output tensor of type float32[1, 2], with a optionally attached label file. If a label
  9647. - * file is attached, the file should be a plain text file with one label per line, the number
  9648. - * of labels should match the number of categories the model outputs.
  9649. + * 1 output tensor of type float32[1, 2], with a optionally attached label
  9650. + * file. If a label file is attached, the file should be a plain text file with
  9651. + * one label per line, the number of labels should match the number of
  9652. + * categories the model outputs.
  9653. */
  9654. @interface TFLBertNLClassifier : NSObject
  9655. @@ -45,7 +47,7 @@ NS_ASSUME_NONNULL_BEGIN
  9656. * @param modelPath Path to the classification model.
  9657. * @return A TFLBertNLClassifier instance.
  9658. */
  9659. -+ (instancetype)bertNLClassifierWithModelPath:(NSString *)modelPath
  9660. ++ (instancetype)bertNLClassifierWithModelPath:(NSString*)modelPath
  9661. NS_SWIFT_NAME(bertNLClassifier(modelPath:));
  9662. /**
  9663. @@ -54,8 +56,9 @@ NS_ASSUME_NONNULL_BEGIN
  9664. * @param modelPath Path to the classification model.
  9665. * @return A TFLBertNLClassifier instance.
  9666. */
  9667. -+ (instancetype)bertNLClassifierWithModelPath:(NSString *)modelPath
  9668. - options:(TFLBertNLClassifierOptions *)options
  9669. ++ (instancetype)bertNLClassifierWithModelPath:(NSString*)modelPath
  9670. + options:
  9671. + (TFLBertNLClassifierOptions*)options
  9672. NS_SWIFT_NAME(bertNLClassifier(modelPath:options:));
  9673. /**
  9674. @@ -65,7 +68,7 @@ NS_ASSUME_NONNULL_BEGIN
  9675. * @param text input text to the model.
  9676. * @return A NSDictionary of categorization results.
  9677. */
  9678. -- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text
  9679. +- (NSDictionary<NSString*, NSNumber*>*)classifyWithText:(NSString*)text
  9680. NS_SWIFT_NAME(classify(text:));
  9681. @end
  9682. NS_ASSUME_NONNULL_END
  9683. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h
  9684. index ceb8d2ef9a307..41eb0fb76c9ea 100644
  9685. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h
  9686. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h
  9687. @@ -23,14 +23,14 @@ NS_ASSUME_NONNULL_BEGIN
  9688. @property(nonatomic) int inputTensorIndex;
  9689. @property(nonatomic) int outputScoreTensorIndex;
  9690. @property(nonatomic) int outputLabelTensorIndex;
  9691. -@property(nonatomic) NSString *inputTensorName;
  9692. -@property(nonatomic) NSString *outputScoreTensorName;
  9693. -@property(nonatomic) NSString *outputLabelTensorName;
  9694. +@property(nonatomic) NSString* inputTensorName;
  9695. +@property(nonatomic) NSString* outputScoreTensorName;
  9696. +@property(nonatomic) NSString* outputLabelTensorName;
  9697. @end
  9698. /**
  9699. - * Classifier API for natural language classification tasks, categorizes string into different
  9700. - * classes.
  9701. + * Classifier API for natural language classification tasks, categorizes string
  9702. + * into different classes.
  9703. *
  9704. * The API expects a TFLite model with the following input/output tensor:
  9705. *
  9706. @@ -39,25 +39,28 @@ NS_ASSUME_NONNULL_BEGIN
  9707. *
  9708. * Output score tensor
  9709. * (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64/kTfLiteBool)
  9710. - * output scores for each class, if type is one of the Int types, dequantize it, if it
  9711. - * is Bool type, convert the values to 0.0 and 1.0 respectively.
  9712. + * output scores for each class, if type is one of the Int types, dequantize
  9713. + * it, if it is Bool type, convert the values to 0.0 and 1.0 respectively.
  9714. *
  9715. - * can have an optional associated file in metadata for labels, the file should be a
  9716. - * plain text file with one label per line, the number of labels should match the number
  9717. - * of categories the model outputs. Output label tensor: optional (kTfLiteString) -
  9718. - * output classname for each class, should be of the same length with scores. If this
  9719. - * tensor is not present, the API uses score indices as classnames. - will be ignored if
  9720. - * output score tensor already has an associated label file.
  9721. + * can have an optional associated file in metadata for labels, the file
  9722. + * should be a plain text file with one label per line, the number of labels
  9723. + * should match the number of categories the model outputs. Output label tensor:
  9724. + * optional (kTfLiteString) - output classname for each class, should be of the
  9725. + * same length with scores. If this tensor is not present, the API uses score
  9726. + * indices as classnames. - will be ignored if output score tensor already has
  9727. + * an associated label file.
  9728. *
  9729. * Optional Output label tensor (kTfLiteString/kTfLiteInt32)
  9730. - * output classname for each class, should be of the same length with scores. If this
  9731. - * tensor is not present, the API uses score indices as classnames.
  9732. + * output classname for each class, should be of the same length with
  9733. + * scores. If this tensor is not present, the API uses score indices as
  9734. + * classnames.
  9735. *
  9736. - * will be ignored if output score tensor already has an associated labe file.
  9737. + * will be ignored if output score tensor already has an associated labe
  9738. + * file.
  9739. *
  9740. - * By default the API tries to find the input/output tensors with default configurations in
  9741. - * TFLNLClassifierOptions, with tensor name prioritized over tensor index. The option is
  9742. - * configurable for different TFLite models.
  9743. + * By default the API tries to find the input/output tensors with default
  9744. + * configurations in TFLNLClassifierOptions, with tensor name prioritized over
  9745. + * tensor index. The option is configurable for different TFLite models.
  9746. */
  9747. @interface TFLNLClassifier : NSObject
  9748. @@ -69,8 +72,8 @@ NS_ASSUME_NONNULL_BEGIN
  9749. *
  9750. * @return A TFLNLClassifier instance.
  9751. */
  9752. -+ (instancetype)nlClassifierWithModelPath:(NSString *)modelPath
  9753. - options:(TFLNLClassifierOptions *)options
  9754. ++ (instancetype)nlClassifierWithModelPath:(NSString*)modelPath
  9755. + options:(TFLNLClassifierOptions*)options
  9756. NS_SWIFT_NAME(nlClassifier(modelPath:options:));
  9757. /**
  9758. @@ -80,7 +83,7 @@ NS_ASSUME_NONNULL_BEGIN
  9759. * @param text input text to the model.
  9760. * @return A NSDictionary of categorization results.
  9761. */
  9762. -- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text
  9763. +- (NSDictionary<NSString*, NSNumber*>*)classifyWithText:(NSString*)text
  9764. NS_SWIFT_NAME(classify(text:));
  9765. @end
  9766. NS_ASSUME_NONNULL_END
  9767. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h
  9768. index 57b7c69c70f62..446e2cb137dd9 100644
  9769. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h
  9770. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h
  9771. @@ -54,13 +54,13 @@ struct TFLPos {
  9772. * @param modelPath The file path to the tflite model.
  9773. * @return A BertQuestionAnswerer instance.
  9774. */
  9775. -+ (instancetype)questionAnswererWithModelPath:(NSString *)modelPath
  9776. ++ (instancetype)questionAnswererWithModelPath:(NSString*)modelPath
  9777. NS_SWIFT_NAME(questionAnswerer(modelPath:));
  9778. /**
  9779. * Answers question based on the context. Could be empty if no answer was found
  9780. * from the given context.
  9781. - *
  9782. + *
  9783. * @param context Context the question bases on.
  9784. * @param question Question to ask.
  9785. *
  9786. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h
  9787. index f228034147c40..7e38abe002623 100644
  9788. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h
  9789. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h
  9790. @@ -31,29 +31,32 @@ NS_SWIFT_NAME(ImageClassifierOptions)
  9791. * Base options that are used for creation of any type of task.
  9792. * @discussion Please see `TFLBaseOptions` for more details.
  9793. */
  9794. -@property(nonatomic, copy) TFLBaseOptions *baseOptions;
  9795. +@property(nonatomic, copy) TFLBaseOptions* baseOptions;
  9796. /**
  9797. * Options that configure the display and filtering of results.
  9798. * @discussion Please see `TFLClassificationOptions` for more details.
  9799. */
  9800. -@property(nonatomic, copy) TFLClassificationOptions *classificationOptions;
  9801. +@property(nonatomic, copy) TFLClassificationOptions* classificationOptions;
  9802. /**
  9803. - * Initializes a new `TFLImageClassifierOptions` with the absolute path to the model file
  9804. - * stored locally on the device, set to the given the model path.
  9805. + * Initializes a new `TFLImageClassifierOptions` with the absolute path to the
  9806. + * model file stored locally on the device, set to the given the model path.
  9807. *
  9808. - * @discussion The external model file, must be a single standalone TFLite file. It could be packed
  9809. - * with TFLite Model Metadata[1] and associated files if exist. Fail to provide the necessary
  9810. - * metadata and associated files might result in errors. Check the [documentation]
  9811. - * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement.
  9812. + * @discussion The external model file, must be a single standalone TFLite file.
  9813. + * It could be packed with TFLite Model Metadata[1] and associated files if
  9814. + * exist. Fail to provide the necessary metadata and associated files might
  9815. + * result in errors. Check the [documentation]
  9816. + * (https://www.tensorflow.org/lite/convert/metadata) for each task about the
  9817. + * specific requirement.
  9818. *
  9819. - * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
  9820. + * @param modelPath An absolute path to a TensorFlow Lite model file stored
  9821. + * locally on the device.
  9822. *
  9823. * @return An instance of `TFLImageClassifierOptions` initialized to the given
  9824. * model path.
  9825. */
  9826. -- (instancetype)initWithModelPath:(NSString *)modelPath;
  9827. +- (instancetype)initWithModelPath:(NSString*)modelPath;
  9828. @end
  9829. @@ -64,17 +67,19 @@ NS_SWIFT_NAME(ImageClassifier)
  9830. @interface TFLImageClassifier : NSObject
  9831. /**
  9832. - * Creates a new instance of `TFLImageClassifier` from the given `TFLImageClassifierOptions`.
  9833. + * Creates a new instance of `TFLImageClassifier` from the given
  9834. + * `TFLImageClassifierOptions`.
  9835. *
  9836. * @param options The options to use for configuring the `TFLImageClassifier`.
  9837. - * @param error An optional error parameter populated when there is an error in initializing
  9838. - * the image classifier.
  9839. + * @param error An optional error parameter populated when there is an error in
  9840. + * initializing the image classifier.
  9841. *
  9842. - * @return A new instance of `TFLImageClassifier` with the given options. `nil` if there is an error
  9843. - * in initializing the image classifier.
  9844. + * @return A new instance of `TFLImageClassifier` with the given options. `nil`
  9845. + * if there is an error in initializing the image classifier.
  9846. */
  9847. -+ (nullable instancetype)imageClassifierWithOptions:(TFLImageClassifierOptions *)options
  9848. - error:(NSError **)error
  9849. ++ (nullable instancetype)imageClassifierWithOptions:
  9850. + (TFLImageClassifierOptions*)options
  9851. + error:(NSError**)error
  9852. NS_SWIFT_NAME(classifier(options:));
  9853. + (instancetype)new NS_UNAVAILABLE;
  9854. @@ -82,46 +87,49 @@ NS_SWIFT_NAME(ImageClassifier)
  9855. /**
  9856. * Performs classification on the given GMLImage.
  9857. *
  9858. - * @discussion This method currently supports classification of only the following types of images:
  9859. + * @discussion This method currently supports classification of only the
  9860. + * following types of images:
  9861. * 1. RGB and RGBA images for `GMLImageSourceTypeImage`.
  9862. * 2. kCVPixelFormatType_32BGRA for `GMLImageSourceTypePixelBuffer` and
  9863. - * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to setup
  9864. - * camera and get the frames for inference, you must request for this format
  9865. - * from AVCaptureVideoDataOutput. Otherwise your classification
  9866. - * results will be wrong.
  9867. + * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to
  9868. + * setup camera and get the frames for inference, you must request for this
  9869. + * format from AVCaptureVideoDataOutput. Otherwise your classification results
  9870. + * will be wrong.
  9871. *
  9872. * @param image An image to be classified, represented as a `GMLImage`.
  9873. *
  9874. - * @return A TFLClassificationResult with one set of results per image classifier head. `nil` if
  9875. - * there is an error encountered during classification. Please see `TFLClassificationResult` for
  9876. - * more details.
  9877. + * @return A TFLClassificationResult with one set of results per image
  9878. + * classifier head. `nil` if there is an error encountered during
  9879. + * classification. Please see `TFLClassificationResult` for more details.
  9880. */
  9881. -- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image
  9882. - error:(NSError **)error
  9883. +- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image
  9884. + error:(NSError**)error
  9885. NS_SWIFT_NAME(classify(mlImage:));
  9886. /**
  9887. - * Performs classification on the pixels within the specified region of interest of the given
  9888. - * `GMLImage`.
  9889. + * Performs classification on the pixels within the specified region of interest
  9890. + * of the given `GMLImage`.
  9891. *
  9892. - * @discussion This method currently supports inference on only following type of images:
  9893. + * @discussion This method currently supports inference on only following type
  9894. + * of images:
  9895. * 1. RGB and RGBA images for `GMLImageSourceTypeImage`.
  9896. * 2. kCVPixelFormatType_32BGRA for `GMLImageSourceTypePixelBuffer` and
  9897. - * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to setup
  9898. - * camera and get the frames for inference, you must request for this format
  9899. - * from AVCaptureVideoDataOutput. Otherwise your classification
  9900. - * results will be wrong.
  9901. + * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to
  9902. + * setup camera and get the frames for inference, you must request for this
  9903. + * format from AVCaptureVideoDataOutput. Otherwise your classification results
  9904. + * will be wrong.
  9905. *
  9906. * @param image An image to be classified, represented as a `GMLImage`.
  9907. - * @param roi A CGRect specifying the region of interest within the given `GMLImage`, on which
  9908. - * classification should be performed.
  9909. + * @param roi A CGRect specifying the region of interest within the given
  9910. + * `GMLImage`, on which classification should be performed.
  9911. *
  9912. - * @return A TFLClassificationResult with one set of results per image classifier head. `nil` if
  9913. - * there is an error encountered during classification.
  9914. + * @return A TFLClassificationResult with one set of results per image
  9915. + * classifier head. `nil` if there is an error encountered during
  9916. + * classification.
  9917. */
  9918. -- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image
  9919. - regionOfInterest:(CGRect)roi
  9920. - error:(NSError **)error
  9921. +- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image
  9922. + regionOfInterest:(CGRect)roi
  9923. + error:(NSError**)error
  9924. NS_SWIFT_NAME(classify(mlImage:regionOfInterest:));
  9925. - (instancetype)init NS_UNAVAILABLE;
  9926. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m
  9927. index f8c09527bd902..79ad474054525 100644
  9928. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m
  9929. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m
  9930. @@ -40,7 +40,7 @@
  9931. return self;
  9932. }
  9933. -- (instancetype)initWithModelPath:(NSString *)modelPath {
  9934. +- (instancetype)initWithModelPath:(NSString*)modelPath {
  9935. self = [self init];
  9936. if (self) {
  9937. self.baseOptions.modelFile.filePath = modelPath;
  9938. @@ -63,40 +63,45 @@
  9939. return self;
  9940. }
  9941. -+ (nullable instancetype)imageClassifierWithOptions:(TFLImageClassifierOptions *)options
  9942. - error:(NSError **)error {
  9943. ++ (nullable instancetype)imageClassifierWithOptions:
  9944. + (TFLImageClassifierOptions*)options
  9945. + error:(NSError**)error {
  9946. if (!options) {
  9947. - [TFLCommonUtils createCustomError:error
  9948. - withCode:TFLSupportErrorCodeInvalidArgumentError
  9949. - description:@"TFLImageClassifierOptions argument cannot be nil."];
  9950. + [TFLCommonUtils
  9951. + createCustomError:error
  9952. + withCode:TFLSupportErrorCodeInvalidArgumentError
  9953. + description:@"TFLImageClassifierOptions argument cannot be nil."];
  9954. return nil;
  9955. }
  9956. TfLiteImageClassifierOptions cOptions = TfLiteImageClassifierOptionsCreate();
  9957. - if (![options.classificationOptions copyToCOptions:&(cOptions.classification_options)
  9958. - error:error]) {
  9959. - [options.classificationOptions
  9960. - deleteAllocatedMemoryOfClassificationOptions:&(cOptions.classification_options)];
  9961. + if (![options.classificationOptions
  9962. + copyToCOptions:&(cOptions.classification_options)
  9963. + error:error]) {
  9964. + [options.classificationOptions deleteAllocatedMemoryOfClassificationOptions:
  9965. + &(cOptions.classification_options)];
  9966. return nil;
  9967. }
  9968. [options.baseOptions copyToCOptions:&(cOptions.base_options)];
  9969. - TfLiteSupportError *cCreateClassifierError = NULL;
  9970. - TfLiteImageClassifier *cImageClassifier =
  9971. + TfLiteSupportError* cCreateClassifierError = NULL;
  9972. + TfLiteImageClassifier* cImageClassifier =
  9973. TfLiteImageClassifierFromOptions(&cOptions, &cCreateClassifierError);
  9974. - [options.classificationOptions
  9975. - deleteAllocatedMemoryOfClassificationOptions:&(cOptions.classification_options)];
  9976. + [options.classificationOptions deleteAllocatedMemoryOfClassificationOptions:
  9977. + &(cOptions.classification_options)];
  9978. - // Populate iOS error if TfliteSupportError is not null and afterwards delete it.
  9979. + // Populate iOS error if TfliteSupportError is not null and afterwards delete
  9980. + // it.
  9981. if (![TFLCommonUtils checkCError:cCreateClassifierError toError:error]) {
  9982. TfLiteSupportErrorDelete(cCreateClassifierError);
  9983. }
  9984. - // Return nil if classifier evaluates to nil. If an error was generted by the C layer, it has
  9985. - // already been populated to an NSError and deleted before returning from the method.
  9986. + // Return nil if classifier evaluates to nil. If an error was generted by the
  9987. + // C layer, it has already been populated to an NSError and deleted before
  9988. + // returning from the method.
  9989. if (!cImageClassifier) {
  9990. return nil;
  9991. }
  9992. @@ -104,16 +109,16 @@
  9993. return [[TFLImageClassifier alloc] initWithImageClassifier:cImageClassifier];
  9994. }
  9995. -- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image
  9996. - error:(NSError **)error {
  9997. +- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image
  9998. + error:(NSError**)error {
  9999. return [self classifyWithGMLImage:image
  10000. regionOfInterest:CGRectMake(0, 0, image.width, image.height)
  10001. error:error];
  10002. }
  10003. -- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image
  10004. - regionOfInterest:(CGRect)roi
  10005. - error:(NSError **)error {
  10006. +- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image
  10007. + regionOfInterest:(CGRect)roi
  10008. + error:(NSError**)error {
  10009. if (!image) {
  10010. [TFLCommonUtils createCustomError:error
  10011. withCode:TFLSupportErrorCodeInvalidArgumentError
  10012. @@ -121,7 +126,7 @@
  10013. return nil;
  10014. }
  10015. - TfLiteFrameBuffer *cFrameBuffer = [image cFrameBufferWithError:error];
  10016. + TfLiteFrameBuffer* cFrameBuffer = [image cFrameBufferWithError:error];
  10017. if (!cFrameBuffer) {
  10018. return nil;
  10019. @@ -132,7 +137,7 @@
  10020. .width = roi.size.width,
  10021. .height = roi.size.height};
  10022. - TfLiteSupportError *classifyError = NULL;
  10023. + TfLiteSupportError* classifyError = NULL;
  10024. TfLiteClassificationResult *cClassificationResult = TfLiteImageClassifierClassifyWithRoi(
  10025. _imageClassifier, cFrameBuffer, &boundingBox, &classifyError);
  10026. @@ -147,8 +152,9 @@
  10027. TfLiteSupportErrorDelete(classifyError);
  10028. }
  10029. - // Return nil if C result evaluates to nil. If an error was generted by the C layer, it has
  10030. - // already been populated to an NSError and deleted before returning from the method.
  10031. + // Return nil if C result evaluates to nil. If an error was generted by the C
  10032. + // layer, it has already been populated to an NSError and deleted before
  10033. + // returning from the method.
  10034. if (!cClassificationResult) {
  10035. return nil;
  10036. }
  10037. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h
  10038. index 7b556dcd312e2..234e10d68b319 100644
  10039. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h
  10040. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h
  10041. @@ -20,9 +20,10 @@
  10042. NS_ASSUME_NONNULL_BEGIN
  10043. /**
  10044. - * Specifies the type of the output segmentation mask to be returned as the result
  10045. - * of the image segmentation operation. This directs the `TFLImageSegmenter` to
  10046. - * choose the type of post-processing to be performed on the raw model results.
  10047. + * Specifies the type of the output segmentation mask to be returned as the
  10048. + * result of the image segmentation operation. This directs the
  10049. + * `TFLImageSegmenter` to choose the type of post-processing to be performed on
  10050. + * the raw model results.
  10051. */
  10052. typedef NS_ENUM(NSUInteger, TFLOutputType) {
  10053. /** Unspecified output type. */
  10054. @@ -52,7 +53,7 @@ NS_SWIFT_NAME(ImageSegmenterOptions)
  10055. * Base options that is used for creation of any type of task.
  10056. * @discussion Please see `TFLBaseOptions` for more details.
  10057. */
  10058. -@property(nonatomic, copy) TFLBaseOptions *baseOptions;
  10059. +@property(nonatomic, copy) TFLBaseOptions* baseOptions;
  10060. /**
  10061. * Specifies the type of output segmentation mask to be returned as a result
  10062. @@ -63,24 +64,26 @@ NS_SWIFT_NAME(ImageSegmenterOptions)
  10063. /**
  10064. * Display names local for display names
  10065. */
  10066. -@property(nonatomic, copy) NSString *displayNamesLocale;
  10067. +@property(nonatomic, copy) NSString* displayNamesLocale;
  10068. /**
  10069. - * Initializes a new `TFLImageSegmenterOptions` with the absolute path to the model file
  10070. - * stored locally on the device, set to the given the model path.
  10071. + * Initializes a new `TFLImageSegmenterOptions` with the absolute path to the
  10072. + * model file stored locally on the device, set to the given the model path.
  10073. * .
  10074. * @discussion The external model file, must be a single standalone TFLite
  10075. * file. It could be packed with TFLite Model Metadata[1] and associated files
  10076. * if exist. Fail to provide the necessary metadata and associated files might
  10077. - * result in errors. Check the [documentation](https://www.tensorflow.org/lite/convert/metadata)
  10078. - * for each task about the specific requirement.
  10079. + * result in errors. Check the
  10080. + * [documentation](https://www.tensorflow.org/lite/convert/metadata) for each
  10081. + * task about the specific requirement.
  10082. *
  10083. - * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
  10084. + * @param modelPath An absolute path to a TensorFlow Lite model file stored
  10085. + * locally on the device.
  10086. *
  10087. * @return An instance of `TFLImageSegmenterOptions` initialized to the given
  10088. * model path.
  10089. */
  10090. -- (instancetype)initWithModelPath:(NSString *)modelPath;
  10091. +- (instancetype)initWithModelPath:(NSString*)modelPath;
  10092. @end
  10093. @@ -88,17 +91,19 @@ NS_SWIFT_NAME(ImageSegmenter)
  10094. @interface TFLImageSegmenter : NSObject
  10095. /**
  10096. - * Creates a new instance of `TFLImageSegmenter` from the given `TFLImageSegmenterOptions`.
  10097. + * Creates a new instance of `TFLImageSegmenter` from the given
  10098. + * `TFLImageSegmenterOptions`.
  10099. *
  10100. * @param options The options to use for configuring the `TFLImageSegmenter`.
  10101. - * @param error An optional error parameter populated when there is an error in initializing
  10102. - * the image segmenter.
  10103. + * @param error An optional error parameter populated when there is an error in
  10104. + * initializing the image segmenter.
  10105. *
  10106. - * @return A new instance of `TFLImageSegmenter` with the given options. `nil` if there is an error
  10107. - * in initializing the image segmenter.
  10108. + * @return A new instance of `TFLImageSegmenter` with the given options. `nil`
  10109. + * if there is an error in initializing the image segmenter.
  10110. */
  10111. -+ (nullable instancetype)imageSegmenterWithOptions:(nonnull TFLImageSegmenterOptions *)options
  10112. - error:(NSError **)error
  10113. ++ (nullable instancetype)imageSegmenterWithOptions:
  10114. + (nonnull TFLImageSegmenterOptions*)options
  10115. + error:(NSError**)error
  10116. NS_SWIFT_NAME(segmenter(options:));
  10117. + (instancetype)new NS_UNAVAILABLE;
  10118. @@ -106,22 +111,23 @@ NS_SWIFT_NAME(ImageSegmenter)
  10119. /**
  10120. * Performs segmentation on the given GMLImage.
  10121. *
  10122. - * @discussion This method currently supports segmentation of only the following types of images:
  10123. + * @discussion This method currently supports segmentation of only the following
  10124. + * types of images:
  10125. * 1. RGB and RGBA images for `GMLImageSourceTypeImage`.
  10126. * 2. kCVPixelFormatType_32BGRA for `GMLImageSourceTypePixelBuffer` and
  10127. - * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to setup
  10128. - * camera and get the frames for inference, you must request for this format
  10129. - * from AVCaptureVideoDataOutput. Otherwise your segmentation
  10130. - * results will be wrong.
  10131. + * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to
  10132. + * setup camera and get the frames for inference, you must request for this
  10133. + * format from AVCaptureVideoDataOutput. Otherwise your segmentation results
  10134. + * will be wrong.
  10135. *
  10136. * @param image An image to be segmented, represented as a `GMLImage`.
  10137. *
  10138. - * @return A TFLSegmentationResult that holds the segmentation masks returned by the image
  10139. - * segmentation task. `nil` if there is an error encountered during segmentation. Please see
  10140. - * `TFLSegmentationResult` for more details.
  10141. + * @return A TFLSegmentationResult that holds the segmentation masks returned by
  10142. + * the image segmentation task. `nil` if there is an error encountered during
  10143. + * segmentation. Please see `TFLSegmentationResult` for more details.
  10144. */
  10145. -- (nullable TFLSegmentationResult *)segmentWithGMLImage:(GMLImage *)image
  10146. - error:(NSError **)error
  10147. +- (nullable TFLSegmentationResult*)segmentWithGMLImage:(GMLImage*)image
  10148. + error:(NSError**)error
  10149. NS_SWIFT_NAME(segment(mlImage:));
  10150. - (instancetype)init NS_UNAVAILABLE;
  10151. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.m
  10152. index 70068bfdd645a..7b7f3211df952 100644
  10153. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.m
  10154. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.m
  10155. @@ -35,7 +35,7 @@
  10156. return self;
  10157. }
  10158. -- (instancetype)initWithModelPath:(NSString *)modelPath {
  10159. +- (instancetype)initWithModelPath:(NSString*)modelPath {
  10160. self = [self init];
  10161. if (self) {
  10162. self.baseOptions.modelFile.filePath = modelPath;
  10163. @@ -47,14 +47,14 @@
  10164. @implementation TFLImageSegmenter {
  10165. /** ImageSegmenter backed by C API */
  10166. - TfLiteImageSegmenter *_imageSegmenter;
  10167. + TfLiteImageSegmenter* _imageSegmenter;
  10168. }
  10169. - (void)dealloc {
  10170. TfLiteImageSegmenterDelete(_imageSegmenter);
  10171. }
  10172. -- (instancetype)initWithImageSegmenter:(TfLiteImageSegmenter *)imageSegmenter {
  10173. +- (instancetype)initWithImageSegmenter:(TfLiteImageSegmenter*)imageSegmenter {
  10174. self = [super init];
  10175. if (self) {
  10176. _imageSegmenter = imageSegmenter;
  10177. @@ -62,8 +62,9 @@
  10178. return self;
  10179. }
  10180. -+ (nullable instancetype)imageSegmenterWithOptions:(nonnull TFLImageSegmenterOptions *)options
  10181. - error:(NSError **)error {
  10182. ++ (nullable instancetype)imageSegmenterWithOptions:
  10183. + (nonnull TFLImageSegmenterOptions*)options
  10184. + error:(NSError**)error {
  10185. TfLiteImageSegmenterOptions cOptions = TfLiteImageSegmenterOptionsCreate();
  10186. [options.baseOptions copyToCOptions:&(cOptions.base_options)];
  10187. @@ -71,20 +72,22 @@
  10188. if (options.displayNamesLocale) {
  10189. if (options.displayNamesLocale.UTF8String) {
  10190. - cOptions.display_names_locale = strdup(options.displayNamesLocale.UTF8String);
  10191. + cOptions.display_names_locale =
  10192. + strdup(options.displayNamesLocale.UTF8String);
  10193. if (!cOptions.display_names_locale) {
  10194. exit(-1); // Memory Allocation Failed.
  10195. }
  10196. } else {
  10197. - [TFLCommonUtils createCustomError:error
  10198. - withCode:TFLSupportErrorCodeInvalidArgumentError
  10199. - description:@"Could not convert (NSString *) to (char *)."];
  10200. + [TFLCommonUtils
  10201. + createCustomError:error
  10202. + withCode:TFLSupportErrorCodeInvalidArgumentError
  10203. + description:@"Could not convert (NSString *) to (char *)."];
  10204. return nil;
  10205. }
  10206. }
  10207. - TfLiteSupportError *cCreateImageSegmenterError = nil;
  10208. - TfLiteImageSegmenter *cImageSegmenter =
  10209. + TfLiteSupportError* cCreateImageSegmenterError = nil;
  10210. + TfLiteImageSegmenter* cImageSegmenter =
  10211. TfLiteImageSegmenterFromOptions(&cOptions, &cCreateImageSegmenterError);
  10212. // Freeing memory of allocated string.
  10213. @@ -94,16 +97,17 @@
  10214. TfLiteSupportErrorDelete(cCreateImageSegmenterError);
  10215. }
  10216. - // Return nil if C object detector evaluates to nil. If an error was generted by the C layer, it
  10217. - // has already been populated to an NSError and deleted before returning from the method.
  10218. + // Return nil if C object detector evaluates to nil. If an error was generted
  10219. + // by the C layer, it has already been populated to an NSError and deleted
  10220. + // before returning from the method.
  10221. if (!cImageSegmenter) {
  10222. return nil;
  10223. }
  10224. return [[TFLImageSegmenter alloc] initWithImageSegmenter:cImageSegmenter];
  10225. }
  10226. -- (nullable TFLSegmentationResult *)segmentWithGMLImage:(GMLImage *)image
  10227. - error:(NSError **)error {
  10228. +- (nullable TFLSegmentationResult*)segmentWithGMLImage:(GMLImage*)image
  10229. + error:(NSError**)error {
  10230. if (!image) {
  10231. [TFLCommonUtils createCustomError:error
  10232. withCode:TFLSupportErrorCodeInvalidArgumentError
  10233. @@ -111,15 +115,15 @@
  10234. return nil;
  10235. }
  10236. - TfLiteFrameBuffer *cFrameBuffer = [image cFrameBufferWithError:error];
  10237. + TfLiteFrameBuffer* cFrameBuffer = [image cFrameBufferWithError:error];
  10238. if (!cFrameBuffer) {
  10239. return nil;
  10240. }
  10241. - TfLiteSupportError *cSegmentError = nil;
  10242. - TfLiteSegmentationResult *cSegmentationResult =
  10243. - TfLiteImageSegmenterSegment(_imageSegmenter, cFrameBuffer, &cSegmentError);
  10244. + TfLiteSupportError* cSegmentError = nil;
  10245. + TfLiteSegmentationResult* cSegmentationResult = TfLiteImageSegmenterSegment(
  10246. + _imageSegmenter, cFrameBuffer, &cSegmentError);
  10247. free(cFrameBuffer->buffer);
  10248. cFrameBuffer->buffer = nil;
  10249. @@ -132,13 +136,14 @@
  10250. TfLiteSupportErrorDelete(cSegmentError);
  10251. }
  10252. - // Return nil if C result evaluates to nil. If an error was generted by the C layer, it has
  10253. - // already been populated to an NSError and deleted before returning from the method.
  10254. + // Return nil if C result evaluates to nil. If an error was generted by the C
  10255. + // layer, it has already been populated to an NSError and deleted before
  10256. + // returning from the method.
  10257. if (!cSegmentationResult) {
  10258. return nil;
  10259. }
  10260. - TFLSegmentationResult *segmentationResult =
  10261. + TFLSegmentationResult* segmentationResult =
  10262. [TFLSegmentationResult segmentationResultWithCResult:cSegmentationResult];
  10263. TfLiteSegmentationResultDelete(cSegmentationResult);
  10264. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h
  10265. index 5e3a0e7186cfe..db76c90cc6868 100644
  10266. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h
  10267. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h
  10268. @@ -30,28 +30,31 @@ NS_SWIFT_NAME(ObjectDetectorOptions)
  10269. * Base options that is used for creation of any type of task.
  10270. * @discussion Please see `TFLBaseOptions` for more details.
  10271. */
  10272. -@property(nonatomic, copy) TFLBaseOptions *baseOptions;
  10273. +@property(nonatomic, copy) TFLBaseOptions* baseOptions;
  10274. /**
  10275. * Options that configure the display and filtering of results.
  10276. * @discussion Please see `TFLClassificationOptions` for more details.
  10277. */
  10278. -@property(nonatomic, copy) TFLClassificationOptions *classificationOptions;
  10279. +@property(nonatomic, copy) TFLClassificationOptions* classificationOptions;
  10280. /**
  10281. - * Initializes a new `TFLObjectDetectorOptions` with the absolute path to the model file
  10282. - * stored locally on the device, set to the given the model path.
  10283. + * Initializes a new `TFLObjectDetectorOptions` with the absolute path to the
  10284. + * model file stored locally on the device, set to the given the model path.
  10285. *
  10286. - * @discussion The external model file, must be a single standalone TFLite file. It could be packed
  10287. - * with TFLite Model Metadata[1] and associated files if exist. Fail to provide the necessary
  10288. - * metadata and associated files might result in errors. Check the [documentation]
  10289. - * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement.
  10290. + * @discussion The external model file, must be a single standalone TFLite file.
  10291. + * It could be packed with TFLite Model Metadata[1] and associated files if
  10292. + * exist. Fail to provide the necessary metadata and associated files might
  10293. + * result in errors. Check the [documentation]
  10294. + * (https://www.tensorflow.org/lite/convert/metadata) for each task about the
  10295. + * specific requirement.
  10296. *
  10297. - * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
  10298. + * @param modelPath An absolute path to a TensorFlow Lite model file stored
  10299. + * locally on the device.
  10300. * @return An instance of `TFLObjectDetectorOptions` initialized to the given
  10301. * model path.
  10302. */
  10303. -- (instancetype)initWithModelPath:(NSString *)modelPath;
  10304. +- (instancetype)initWithModelPath:(NSString*)modelPath;
  10305. @end
  10306. @@ -59,40 +62,43 @@ NS_SWIFT_NAME(ObjectDetector)
  10307. @interface TFLObjectDetector : NSObject
  10308. /**
  10309. - * Creates a new instance of `TFLObjectDetector` from the given `TFLObjectDetectorOptions`.
  10310. + * Creates a new instance of `TFLObjectDetector` from the given
  10311. + * `TFLObjectDetectorOptions`.
  10312. *
  10313. * @param options The options to use for configuring the `TFLObjectDetector`.
  10314. - * @param error An optional error parameter populated when there is an error in initializing
  10315. - * the object detector.
  10316. + * @param error An optional error parameter populated when there is an error in
  10317. + * initializing the object detector.
  10318. *
  10319. - * @return A new instance of `TFLObjectDetector` with the given options. `nil` if there is an error
  10320. - * in initializing the object detector.
  10321. + * @return A new instance of `TFLObjectDetector` with the given options. `nil`
  10322. + * if there is an error in initializing the object detector.
  10323. */
  10324. -+ (nullable instancetype)objectDetectorWithOptions:(TFLObjectDetectorOptions *)options
  10325. - error:(NSError **)error
  10326. ++ (nullable instancetype)objectDetectorWithOptions:
  10327. + (TFLObjectDetectorOptions*)options
  10328. + error:(NSError**)error
  10329. NS_SWIFT_NAME(detector(options:));
  10330. + (instancetype)new NS_UNAVAILABLE;
  10331. /**
  10332. * Performs object detection on the given GMLImage.
  10333. - * @discussion This method currently supports object detection on only the following types of
  10334. - * images:
  10335. + * @discussion This method currently supports object detection on only the
  10336. + * following types of images:
  10337. * 1. RGB and RGBA images for `GMLImageSourceTypeImage`.
  10338. * 2. `kCVPixelFormatType_32BGRA` for `GMLImageSourceTypePixelBuffer` and
  10339. - * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to setup
  10340. - * camera and get the frames for inference, you must request for this format
  10341. - * from AVCaptureVideoDataOutput. Otherwise your object detection
  10342. - * results will be wrong.
  10343. + * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to
  10344. + * setup camera and get the frames for inference, you must request for this
  10345. + * format from AVCaptureVideoDataOutput. Otherwise your object detection results
  10346. + * will be wrong.
  10347. *
  10348. - * @param image An image on which object detection is to be performed, represented as a `GMLImage`.
  10349. + * @param image An image on which object detection is to be performed,
  10350. + * represented as a `GMLImage`.
  10351. *
  10352. - * @return A `TFLDetectionResult` holding an array of TFLDetection objects, each having a bounding
  10353. - * box specifying the region the were detected in and an array of predicted classes. Please see
  10354. - * `TFLDetectionResult` for more details.
  10355. + * @return A `TFLDetectionResult` holding an array of TFLDetection objects, each
  10356. + * having a bounding box specifying the region the were detected in and an array
  10357. + * of predicted classes. Please see `TFLDetectionResult` for more details.
  10358. */
  10359. -- (nullable TFLDetectionResult *)detectWithGMLImage:(GMLImage *)image
  10360. - error:(NSError **)error
  10361. +- (nullable TFLDetectionResult*)detectWithGMLImage:(GMLImage*)image
  10362. + error:(NSError**)error
  10363. NS_SWIFT_NAME(detect(mlImage:));
  10364. - (instancetype)init NS_UNAVAILABLE;
  10365. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.m
  10366. index 31cb241a2a448..def2e5b0b4877 100644
  10367. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.m
  10368. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.m
  10369. @@ -40,7 +40,7 @@
  10370. return self;
  10371. }
  10372. -- (instancetype)initWithModelPath:(NSString *)modelPath {
  10373. +- (instancetype)initWithModelPath:(NSString*)modelPath {
  10374. self = [self init];
  10375. if (self) {
  10376. self.baseOptions.modelFile.filePath = modelPath;
  10377. @@ -63,40 +63,45 @@
  10378. return self;
  10379. }
  10380. -+ (nullable instancetype)objectDetectorWithOptions:(TFLObjectDetectorOptions *)options
  10381. - error:(NSError **)error {
  10382. ++ (nullable instancetype)objectDetectorWithOptions:
  10383. + (TFLObjectDetectorOptions*)options
  10384. + error:(NSError**)error {
  10385. if (!options) {
  10386. - [TFLCommonUtils createCustomError:error
  10387. - withCode:TFLSupportErrorCodeInvalidArgumentError
  10388. - description:@"TFLObjectDetectorOptions argument cannot be nil."];
  10389. + [TFLCommonUtils
  10390. + createCustomError:error
  10391. + withCode:TFLSupportErrorCodeInvalidArgumentError
  10392. + description:@"TFLObjectDetectorOptions argument cannot be nil."];
  10393. return nil;
  10394. }
  10395. TfLiteObjectDetectorOptions cOptions = TfLiteObjectDetectorOptionsCreate();
  10396. - if (![options.classificationOptions copyToCOptions:&(cOptions.classification_options)
  10397. - error:error]) {
  10398. + if (![options.classificationOptions
  10399. + copyToCOptions:&(cOptions.classification_options)
  10400. + error:error]) {
  10401. // Deallocating any allocated memory on failure.
  10402. - [options.classificationOptions
  10403. - deleteAllocatedMemoryOfClassificationOptions:&(cOptions.classification_options)];
  10404. + [options.classificationOptions deleteAllocatedMemoryOfClassificationOptions:
  10405. + &(cOptions.classification_options)];
  10406. return nil;
  10407. }
  10408. [options.baseOptions copyToCOptions:&(cOptions.base_options)];
  10409. - TfLiteSupportError *cCreateObjectDetectorError = nil;
  10410. - TfLiteObjectDetector *cObjectDetector =
  10411. + TfLiteSupportError* cCreateObjectDetectorError = nil;
  10412. + TfLiteObjectDetector* cObjectDetector =
  10413. TfLiteObjectDetectorFromOptions(&cOptions, &cCreateObjectDetectorError);
  10414. - [options.classificationOptions
  10415. - deleteAllocatedMemoryOfClassificationOptions:&(cOptions.classification_options)];
  10416. + [options.classificationOptions deleteAllocatedMemoryOfClassificationOptions:
  10417. + &(cOptions.classification_options)];
  10418. - // Populate iOS error if TfliteSupportError is not null and afterwards delete it.
  10419. + // Populate iOS error if TfliteSupportError is not null and afterwards delete
  10420. + // it.
  10421. if (![TFLCommonUtils checkCError:cCreateObjectDetectorError toError:error]) {
  10422. TfLiteSupportErrorDelete(cCreateObjectDetectorError);
  10423. }
  10424. - // Return nil if C object detector evaluates to nil. If an error was generted by the C layer, it
  10425. - // has already been populated to an NSError and deleted before returning from the method.
  10426. + // Return nil if C object detector evaluates to nil. If an error was generted
  10427. + // by the C layer, it has already been populated to an NSError and deleted
  10428. + // before returning from the method.
  10429. if (!cObjectDetector) {
  10430. return nil;
  10431. }
  10432. @@ -104,8 +109,8 @@
  10433. return [[TFLObjectDetector alloc] initWithObjectDetector:cObjectDetector];
  10434. }
  10435. -- (nullable TFLDetectionResult *)detectWithGMLImage:(GMLImage *)image
  10436. - error:(NSError **)error {
  10437. +- (nullable TFLDetectionResult*)detectWithGMLImage:(GMLImage*)image
  10438. + error:(NSError**)error {
  10439. if (!image) {
  10440. [TFLCommonUtils createCustomError:error
  10441. withCode:TFLSupportErrorCodeInvalidArgumentError
  10442. @@ -113,14 +118,14 @@
  10443. return nil;
  10444. }
  10445. - TfLiteFrameBuffer *cFrameBuffer = [image cFrameBufferWithError:error];
  10446. + TfLiteFrameBuffer* cFrameBuffer = [image cFrameBufferWithError:error];
  10447. if (!cFrameBuffer) {
  10448. return nil;
  10449. }
  10450. - TfLiteSupportError *cDetectError = nil;
  10451. - TfLiteDetectionResult *cDetectionResult =
  10452. + TfLiteSupportError* cDetectError = nil;
  10453. + TfLiteDetectionResult* cDetectionResult =
  10454. TfLiteObjectDetectorDetect(_objectDetector, cFrameBuffer, &cDetectError);
  10455. free(cFrameBuffer->buffer);
  10456. @@ -134,8 +139,9 @@
  10457. TfLiteSupportErrorDelete(cDetectError);
  10458. }
  10459. - // Return nil if C result evaluates to nil. If an error was generted by the C layer, it has
  10460. - // already been populated to an NSError and deleted before returning from the method.
  10461. + // Return nil if C result evaluates to nil. If an error was generted by the C
  10462. + // layer, it has already been populated to an NSError and deleted before
  10463. + // returning from the method.
  10464. if (!cDetectionResult) {
  10465. return nil;
  10466. }
  10467. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h
  10468. index 77c3e33185b9f..8524903b36602 100644
  10469. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h
  10470. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h
  10471. @@ -36,7 +36,7 @@ NS_ASSUME_NONNULL_BEGIN
  10472. * @return The TfLiteFrameBuffer created from the gmlImage which can be used
  10473. * with the TF Lite Task Vision C library.
  10474. */
  10475. -- (nullable TfLiteFrameBuffer *)cFrameBufferWithError:(NSError *_Nullable *)error;
  10476. +- (nullable TfLiteFrameBuffer*)cFrameBufferWithError:(NSError* _Nullable*)error;
  10477. /**
  10478. * Gets grayscale pixel buffer from GMLImage if source type is
  10479. @@ -61,9 +61,9 @@ NS_ASSUME_NONNULL_BEGIN
  10480. * @return The GMLImage object contains the loaded image. This method returns
  10481. * nil if it cannot load the image.
  10482. */
  10483. -+ (nullable GMLImage *)imageFromBundleWithClass:(Class)classObject
  10484. - fileName:(NSString *)name
  10485. - ofType:(NSString *)type
  10486. ++ (nullable GMLImage*)imageFromBundleWithClass:(Class)classObject
  10487. + fileName:(NSString*)name
  10488. + ofType:(NSString*)type
  10489. NS_SWIFT_NAME(imageFromBundle(class:filename:type:));
  10490. @end
  10491. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.m
  10492. index d1ab5105448fe..532f75ef25a6c 100644
  10493. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.m
  10494. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.m
  10495. @@ -25,35 +25,38 @@
  10496. @interface TFLCVPixelBufferUtils : NSObject
  10497. -+ (TfLiteFrameBuffer *)cFrameBufferWithWidth:(int)width
  10498. - height:(int)height
  10499. - frameBufferFormat:(enum TfLiteFrameBufferFormat)frameBufferFormat
  10500. - buffer:(uint8_t *)buffer
  10501. - error:(NSError **)error;
  10502. ++ (TfLiteFrameBuffer*)cFrameBufferWithWidth:(int)width
  10503. + height:(int)height
  10504. + frameBufferFormat:
  10505. + (enum TfLiteFrameBufferFormat)frameBufferFormat
  10506. + buffer:(uint8_t*)buffer
  10507. + error:(NSError**)error;
  10508. -+ (TfLiteFrameBuffer *)cFramebufferFromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer
  10509. - error:(NSError **)error;
  10510. ++ (TfLiteFrameBuffer*)cFramebufferFromCVPixelBuffer:
  10511. + (CVPixelBufferRef)pixelBuffer
  10512. + error:(NSError**)error;
  10513. @end
  10514. @interface UIImage (RawPixelDataUtils)
  10515. -- (TfLiteFrameBuffer *)frameBufferWithError:(NSError **)error;
  10516. +- (TfLiteFrameBuffer*)frameBufferWithError:(NSError**)error;
  10517. - (CVPixelBufferRef)grayScalePixelBuffer;
  10518. @end
  10519. @implementation TFLCVPixelBufferUtils
  10520. -+ (TfLiteFrameBuffer *)cFrameBufferWithWidth:(int)width
  10521. - height:(int)height
  10522. - frameBufferFormat:(enum TfLiteFrameBufferFormat)frameBufferFormat
  10523. - buffer:(uint8_t *)buffer
  10524. - error:(NSError **)error {
  10525. ++ (TfLiteFrameBuffer*)cFrameBufferWithWidth:(int)width
  10526. + height:(int)height
  10527. + frameBufferFormat:
  10528. + (enum TfLiteFrameBufferFormat)frameBufferFormat
  10529. + buffer:(uint8_t*)buffer
  10530. + error:(NSError**)error {
  10531. if (!buffer) {
  10532. return NULL;
  10533. }
  10534. - TfLiteFrameBuffer *cFrameBuffer = [TFLCommonUtils mallocWithSize:sizeof(TfLiteFrameBuffer)
  10535. - error:error];
  10536. + TfLiteFrameBuffer* cFrameBuffer =
  10537. + [TFLCommonUtils mallocWithSize:sizeof(TfLiteFrameBuffer) error:error];
  10538. if (cFrameBuffer) {
  10539. cFrameBuffer->dimension.width = width;
  10540. @@ -65,17 +68,18 @@
  10541. return cFrameBuffer;
  10542. }
  10543. -+ (uint8_t *)createRGBImageDatafromImageData:(uint8_t *)data
  10544. - withWidth:(size_t)width
  10545. - height:(size_t)height
  10546. - stride:(size_t)stride
  10547. - pixelBufferFormat:(OSType)pixelBufferFormatType
  10548. - error:(NSError **)error {
  10549. ++ (uint8_t*)createRGBImageDatafromImageData:(uint8_t*)data
  10550. + withWidth:(size_t)width
  10551. + height:(size_t)height
  10552. + stride:(size_t)stride
  10553. + pixelBufferFormat:(OSType)pixelBufferFormatType
  10554. + error:(NSError**)error {
  10555. NSInteger destinationChannelCount = 3;
  10556. size_t destinationBytesPerRow = width * destinationChannelCount;
  10557. - uint8_t *destPixelBufferAddress =
  10558. - [TFLCommonUtils mallocWithSize:sizeof(uint8_t) * height * destinationBytesPerRow error:error];
  10559. + uint8_t* destPixelBufferAddress = [TFLCommonUtils
  10560. + mallocWithSize:sizeof(uint8_t) * height * destinationBytesPerRow
  10561. + error:error];
  10562. if (!destPixelBufferAddress) {
  10563. return NULL;
  10564. @@ -95,19 +99,23 @@
  10565. switch (pixelBufferFormatType) {
  10566. case kCVPixelFormatType_32RGBA: {
  10567. - convertError = vImageConvert_RGBA8888toRGB888(&srcBuffer, &destBuffer, kvImageNoFlags);
  10568. + convertError = vImageConvert_RGBA8888toRGB888(&srcBuffer, &destBuffer,
  10569. + kvImageNoFlags);
  10570. break;
  10571. }
  10572. case kCVPixelFormatType_32BGRA: {
  10573. - convertError = vImageConvert_BGRA8888toRGB888(&srcBuffer, &destBuffer, kvImageNoFlags);
  10574. + convertError = vImageConvert_BGRA8888toRGB888(&srcBuffer, &destBuffer,
  10575. + kvImageNoFlags);
  10576. break;
  10577. }
  10578. default: {
  10579. - [TFLCommonUtils createCustomError:error
  10580. - withCode:TFLSupportErrorCodeInvalidArgumentError
  10581. - description:@"Invalid source pixel buffer format. Expecting one of "
  10582. - @"kCVPixelFormatType_32RGBA, kCVPixelFormatType_32BGRA, "
  10583. - @"kCVPixelFormatType_32ARGB"];
  10584. + [TFLCommonUtils
  10585. + createCustomError:error
  10586. + withCode:TFLSupportErrorCodeInvalidArgumentError
  10587. + description:
  10588. + @"Invalid source pixel buffer format. Expecting one of "
  10589. + @"kCVPixelFormatType_32RGBA, kCVPixelFormatType_32BGRA, "
  10590. + @"kCVPixelFormatType_32ARGB"];
  10591. free(destPixelBufferAddress);
  10592. return NULL;
  10593. @@ -126,16 +134,17 @@
  10594. return destPixelBufferAddress;
  10595. }
  10596. -+ (uint8_t *)createRGBImageDatafromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer
  10597. - error:(NSError **)error {
  10598. ++ (uint8_t*)createRGBImageDatafromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer
  10599. + error:(NSError**)error {
  10600. CVPixelBufferLockBaseAddress(pixelBuffer, 0);
  10601. - uint8_t *rgbData = [TFLCVPixelBufferUtils
  10602. + uint8_t* rgbData = [TFLCVPixelBufferUtils
  10603. createRGBImageDatafromImageData:CVPixelBufferGetBaseAddress(pixelBuffer)
  10604. withWidth:CVPixelBufferGetWidth(pixelBuffer)
  10605. height:CVPixelBufferGetHeight(pixelBuffer)
  10606. stride:CVPixelBufferGetBytesPerRow(pixelBuffer)
  10607. - pixelBufferFormat:CVPixelBufferGetPixelFormatType(pixelBuffer)
  10608. + pixelBufferFormat:CVPixelBufferGetPixelFormatType(
  10609. + pixelBuffer)
  10610. error:error];
  10611. CVPixelBufferUnlockBaseAddress(pixelBuffer, 0);
  10612. @@ -143,9 +152,10 @@
  10613. return rgbData;
  10614. }
  10615. -+ (TfLiteFrameBuffer *)cFramebufferFromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer
  10616. - error:(NSError **)error {
  10617. - uint8_t *buffer = NULL;
  10618. ++ (TfLiteFrameBuffer*)cFramebufferFromCVPixelBuffer:
  10619. + (CVPixelBufferRef)pixelBuffer
  10620. + error:(NSError**)error {
  10621. + uint8_t* buffer = NULL;
  10622. enum TfLiteFrameBufferFormat cPixelFormat = kRGB;
  10623. OSType pixelBufferFormat = CVPixelBufferGetPixelFormatType(pixelBuffer);
  10624. @@ -154,14 +164,18 @@
  10625. case kCVPixelFormatType_32BGRA: {
  10626. cPixelFormat = kRGB;
  10627. - buffer = [TFLCVPixelBufferUtils createRGBImageDatafromCVPixelBuffer:pixelBuffer error:error];
  10628. + buffer =
  10629. + [TFLCVPixelBufferUtils createRGBImageDatafromCVPixelBuffer:pixelBuffer
  10630. + error:error];
  10631. break;
  10632. }
  10633. default: {
  10634. - [TFLCommonUtils createCustomError:error
  10635. - withCode:TFLSupportErrorCodeInvalidArgumentError
  10636. - description:@"Unsupported pixel format for CVPixelBuffer. Supported "
  10637. - @"pixel format types are kCVPixelFormatType_32BGRA"];
  10638. + [TFLCommonUtils
  10639. + createCustomError:error
  10640. + withCode:TFLSupportErrorCodeInvalidArgumentError
  10641. + description:
  10642. + @"Unsupported pixel format for CVPixelBuffer. Supported "
  10643. + @"pixel format types are kCVPixelFormatType_32BGRA"];
  10644. }
  10645. }
  10646. @@ -176,8 +190,8 @@
  10647. @implementation UIImage (RawPixelDataUtils)
  10648. -- (TfLiteFrameBuffer *)frameBufferWithError:(NSError **)error {
  10649. - TfLiteFrameBuffer *frameBuffer = nil;
  10650. +- (TfLiteFrameBuffer*)frameBufferWithError:(NSError**)error {
  10651. + TfLiteFrameBuffer* frameBuffer = nil;
  10652. if (self.CGImage) {
  10653. frameBuffer = [self frameBufferFromCGImage:self.CGImage error:error];
  10654. @@ -202,59 +216,65 @@
  10655. }
  10656. CGDataProviderRef imageDataProvider = CGImageGetDataProvider(cgImage);
  10657. - CFMutableDataRef mutableDataRef =
  10658. - CFDataCreateMutableCopy(kCFAllocatorDefault, 0, CGDataProviderCopyData(imageDataProvider));
  10659. + CFMutableDataRef mutableDataRef = CFDataCreateMutableCopy(
  10660. + kCFAllocatorDefault, 0, CGDataProviderCopyData(imageDataProvider));
  10661. - UInt8 *pixelData = CFDataGetMutableBytePtr(mutableDataRef);
  10662. + UInt8* pixelData = CFDataGetMutableBytePtr(mutableDataRef);
  10663. - if (pixelData == nil) return nil;
  10664. + if (pixelData == nil)
  10665. + return nil;
  10666. CVPixelBufferRef cvPixelBuffer = nil;
  10667. - CVPixelBufferCreateWithBytes(kCFAllocatorDefault, CGImageGetWidth(cgImage),
  10668. - CGImageGetHeight(cgImage), kCVPixelFormatType_OneComponent8,
  10669. - pixelData, CGImageGetBytesPerRow(cgImage), nil, nil, options,
  10670. - &cvPixelBuffer);
  10671. + CVPixelBufferCreateWithBytes(
  10672. + kCFAllocatorDefault, CGImageGetWidth(cgImage), CGImageGetHeight(cgImage),
  10673. + kCVPixelFormatType_OneComponent8, pixelData,
  10674. + CGImageGetBytesPerRow(cgImage), nil, nil, options, &cvPixelBuffer);
  10675. return cvPixelBuffer;
  10676. }
  10677. -+ (UInt8 *_Nullable)pixelDataFromCGImage:(CGImageRef)cgImage error:(NSError **)error {
  10678. ++ (UInt8* _Nullable)pixelDataFromCGImage:(CGImageRef)cgImage
  10679. + error:(NSError**)error {
  10680. size_t width = CGImageGetWidth(cgImage);
  10681. size_t height = CGImageGetHeight(cgImage);
  10682. NSInteger bitsPerComponent = 8;
  10683. NSInteger channelCount = 4;
  10684. - UInt8 *buffer_to_return = NULL;
  10685. + UInt8* buffer_to_return = NULL;
  10686. CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB();
  10687. size_t bytesPerRow = channelCount * width;
  10688. // iOS infers bytesPerRow if it is set to 0.
  10689. - // See https://developer.apple.com/documentation/coregraphics/1455939-cgbitmapcontextcreate
  10690. + // See
  10691. + // https://developer.apple.com/documentation/coregraphics/1455939-cgbitmapcontextcreate
  10692. // But for segmentation test image, this was not the case.
  10693. // Hence setting it to the value of channelCount*width.
  10694. // kCGImageAlphaNoneSkipLast specifies that Alpha will always be next to B.
  10695. // kCGBitmapByteOrder32Big specifies that R will be stored before B.
  10696. // In combination they signify a pixelFormat of kCVPixelFormatType32RGBA.
  10697. - CGBitmapInfo bitMapinfoFor32RGBA = kCGImageAlphaNoneSkipLast | kCGBitmapByteOrder32Big;
  10698. - CGContextRef context = CGBitmapContextCreate(nil, width, height, bitsPerComponent, bytesPerRow,
  10699. - colorSpace, bitMapinfoFor32RGBA);
  10700. + CGBitmapInfo bitMapinfoFor32RGBA =
  10701. + kCGImageAlphaNoneSkipLast | kCGBitmapByteOrder32Big;
  10702. + CGContextRef context =
  10703. + CGBitmapContextCreate(nil, width, height, bitsPerComponent, bytesPerRow,
  10704. + colorSpace, bitMapinfoFor32RGBA);
  10705. if (context) {
  10706. CGContextDrawImage(context, CGRectMake(0, 0, width, height), cgImage);
  10707. - uint8_t *srcData = CGBitmapContextGetData(context);
  10708. + uint8_t* srcData = CGBitmapContextGetData(context);
  10709. if (srcData) {
  10710. - // We have drawn the image as an RGBA image with 8 bitsPerComponent and hence can safely input
  10711. - // a pixel format of type kCVPixelFormatType_32RGBA for conversion by vImage.
  10712. - buffer_to_return =
  10713. - [TFLCVPixelBufferUtils createRGBImageDatafromImageData:srcData
  10714. - withWidth:width
  10715. - height:height
  10716. - stride:bytesPerRow
  10717. - pixelBufferFormat:kCVPixelFormatType_32RGBA
  10718. - error:error];
  10719. + // We have drawn the image as an RGBA image with 8 bitsPerComponent and
  10720. + // hence can safely input a pixel format of type kCVPixelFormatType_32RGBA
  10721. + // for conversion by vImage.
  10722. + buffer_to_return = [TFLCVPixelBufferUtils
  10723. + createRGBImageDatafromImageData:srcData
  10724. + withWidth:width
  10725. + height:height
  10726. + stride:bytesPerRow
  10727. + pixelBufferFormat:kCVPixelFormatType_32RGBA
  10728. + error:error];
  10729. }
  10730. CGContextRelease(context);
  10731. @@ -265,18 +285,21 @@
  10732. return buffer_to_return;
  10733. }
  10734. -- (TfLiteFrameBuffer *)frameBufferFromCGImage:(CGImageRef)cgImage error:(NSError **)error {
  10735. - UInt8 *buffer = [UIImage pixelDataFromCGImage:cgImage error:error];
  10736. +- (TfLiteFrameBuffer*)frameBufferFromCGImage:(CGImageRef)cgImage
  10737. + error:(NSError**)error {
  10738. + UInt8* buffer = [UIImage pixelDataFromCGImage:cgImage error:error];
  10739. - return [TFLCVPixelBufferUtils cFrameBufferWithWidth:(int)CGImageGetWidth(cgImage)
  10740. - height:(int)CGImageGetHeight(cgImage)
  10741. - frameBufferFormat:kRGB
  10742. - buffer:buffer
  10743. - error:error];
  10744. + return [TFLCVPixelBufferUtils
  10745. + cFrameBufferWithWidth:(int)CGImageGetWidth(cgImage)
  10746. + height:(int)CGImageGetHeight(cgImage)
  10747. + frameBufferFormat:kRGB
  10748. + buffer:buffer
  10749. + error:error];
  10750. }
  10751. -- (TfLiteFrameBuffer *)frameBufferFromCIImage:(CIImage *)ciImage error:(NSError **)error {
  10752. - uint8_t *buffer = NULL;
  10753. +- (TfLiteFrameBuffer*)frameBufferFromCIImage:(CIImage*)ciImage
  10754. + error:(NSError**)error {
  10755. + uint8_t* buffer = NULL;
  10756. int width = 0;
  10757. int height = 0;
  10758. @@ -285,17 +308,20 @@
  10759. width = (int)CVPixelBufferGetWidth(ciImage.pixelBuffer);
  10760. height = (int)CVPixelBufferGetHeight(ciImage.pixelBuffer);
  10761. - buffer = [TFLCVPixelBufferUtils createRGBImageDatafromCVPixelBuffer:ciImage.pixelBuffer
  10762. - error:error];
  10763. + buffer = [TFLCVPixelBufferUtils
  10764. + createRGBImageDatafromCVPixelBuffer:ciImage.pixelBuffer
  10765. + error:error];
  10766. } else if (ciImage.CGImage) {
  10767. buffer = [UIImage pixelDataFromCGImage:ciImage.CGImage error:error];
  10768. width = (int)CGImageGetWidth(ciImage.CGImage);
  10769. height = (int)CGImageGetWidth(ciImage.CGImage);
  10770. } else {
  10771. - [TFLCommonUtils createCustomError:error
  10772. - withCode:TFLSupportErrorCodeInvalidArgumentError
  10773. - description:@"CIImage should have CGImage or CVPixelBuffer info."];
  10774. + [TFLCommonUtils
  10775. + createCustomError:error
  10776. + withCode:TFLSupportErrorCodeInvalidArgumentError
  10777. + description:
  10778. + @"CIImage should have CGImage or CVPixelBuffer info."];
  10779. }
  10780. return [TFLCVPixelBufferUtils cFrameBufferWithWidth:width
  10781. @@ -309,19 +335,23 @@
  10782. @implementation GMLImage (Utils)
  10783. -- (nullable TfLiteFrameBuffer *)cFrameBufferWithError:(NSError *_Nullable *)error {
  10784. - TfLiteFrameBuffer *cFrameBuffer = NULL;
  10785. +- (nullable TfLiteFrameBuffer*)cFrameBufferWithError:
  10786. + (NSError* _Nullable*)error {
  10787. + TfLiteFrameBuffer* cFrameBuffer = NULL;
  10788. switch (self.imageSourceType) {
  10789. case GMLImageSourceTypeSampleBuffer: {
  10790. - CVPixelBufferRef sampleImagePixelBuffer = CMSampleBufferGetImageBuffer(self.sampleBuffer);
  10791. - cFrameBuffer = [TFLCVPixelBufferUtils cFramebufferFromCVPixelBuffer:sampleImagePixelBuffer
  10792. - error:error];
  10793. + CVPixelBufferRef sampleImagePixelBuffer =
  10794. + CMSampleBufferGetImageBuffer(self.sampleBuffer);
  10795. + cFrameBuffer = [TFLCVPixelBufferUtils
  10796. + cFramebufferFromCVPixelBuffer:sampleImagePixelBuffer
  10797. + error:error];
  10798. break;
  10799. }
  10800. case GMLImageSourceTypePixelBuffer: {
  10801. - cFrameBuffer = [TFLCVPixelBufferUtils cFramebufferFromCVPixelBuffer:self.pixelBuffer
  10802. - error:error];
  10803. + cFrameBuffer =
  10804. + [TFLCVPixelBufferUtils cFramebufferFromCVPixelBuffer:self.pixelBuffer
  10805. + error:error];
  10806. break;
  10807. }
  10808. case GMLImageSourceTypeImage: {
  10809. @@ -352,14 +382,17 @@
  10810. return nil;
  10811. }
  10812. -+ (GMLImage *)imageFromBundleWithClass:(Class)classObject
  10813. - fileName:(NSString *)name
  10814. - ofType:(NSString *)type {
  10815. - NSString *imagePath = [[NSBundle bundleForClass:classObject] pathForResource:name ofType:type];
  10816. - if (!imagePath) return nil;
  10817. ++ (GMLImage*)imageFromBundleWithClass:(Class)classObject
  10818. + fileName:(NSString*)name
  10819. + ofType:(NSString*)type {
  10820. + NSString* imagePath =
  10821. + [[NSBundle bundleForClass:classObject] pathForResource:name ofType:type];
  10822. + if (!imagePath)
  10823. + return nil;
  10824. - UIImage *image = [[UIImage alloc] initWithContentsOfFile:imagePath];
  10825. - if (!image) return nil;
  10826. + UIImage* image = [[UIImage alloc] initWithContentsOfFile:imagePath];
  10827. + if (!image)
  10828. + return nil;
  10829. return [[GMLImage alloc] initWithImage:image];
  10830. }
  10831. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/TFLRingBufferTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/TFLRingBufferTests.m
  10832. index 3e2df5d4bf023..cd389b9c0a9a8 100644
  10833. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/TFLRingBufferTests.m
  10834. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/TFLRingBufferTests.m
  10835. @@ -17,10 +17,12 @@
  10836. #import "tensorflow_lite_support/ios/sources/TFLCommon.h"
  10837. #import "tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h"
  10838. -#define VerifyError(error, expectedErrorDomain, expectedErrorCode, expectedLocalizedDescription) \
  10839. - XCTAssertEqual(error.domain, expectedErrorDomain); \
  10840. - XCTAssertEqual(error.code, expectedErrorCode); \
  10841. - XCTAssertEqualObjects(error.localizedDescription, expectedLocalizedDescription);
  10842. +#define VerifyError(error, expectedErrorDomain, expectedErrorCode, \
  10843. + expectedLocalizedDescription) \
  10844. + XCTAssertEqual(error.domain, expectedErrorDomain); \
  10845. + XCTAssertEqual(error.code, expectedErrorCode); \
  10846. + XCTAssertEqualObjects(error.localizedDescription, \
  10847. + expectedLocalizedDescription);
  10848. NS_ASSUME_NONNULL_BEGIN
  10849. @@ -33,15 +35,20 @@ NS_ASSUME_NONNULL_BEGIN
  10850. NSInteger inDataLength = 5;
  10851. float inData[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
  10852. - TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inData[0]) size:inDataLength];
  10853. + TFLFloatBuffer* inBuffer =
  10854. + [[TFLFloatBuffer alloc] initWithData:&(inData[0]) size:inDataLength];
  10855. NSInteger bufferSize = 5;
  10856. - TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
  10857. + TFLRingBuffer* ringBuffer =
  10858. + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
  10859. - XCTAssertTrue([ringBuffer loadBuffer:inBuffer offset:0 size:inDataLength error:nil]);
  10860. + XCTAssertTrue([ringBuffer loadBuffer:inBuffer
  10861. + offset:0
  10862. + size:inDataLength
  10863. + error:nil]);
  10864. // State after load: [1.0, 2.0, 3.0, 4.0, 5.0]
  10865. - TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer;
  10866. + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer;
  10867. XCTAssertNotNil(outBuffer);
  10868. XCTAssertEqual(outBuffer.size, bufferSize);
  10869. @@ -55,16 +62,21 @@ NS_ASSUME_NONNULL_BEGIN
  10870. - (void)testLoadSucceedsWithPartialLengthBuffer {
  10871. NSInteger inDataSize = 3;
  10872. float inData[] = {1.0f, 2.0f, 3.0f};
  10873. - TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inData[0]) size:inDataSize];
  10874. + TFLFloatBuffer* inBuffer =
  10875. + [[TFLFloatBuffer alloc] initWithData:&(inData[0]) size:inDataSize];
  10876. NSInteger bufferSize = 5;
  10877. - TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
  10878. + TFLRingBuffer* ringBuffer =
  10879. + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
  10880. - XCTAssertTrue([ringBuffer loadBuffer:inBuffer offset:0 size:inDataSize error:nil]);
  10881. + XCTAssertTrue([ringBuffer loadBuffer:inBuffer
  10882. + offset:0
  10883. + size:inDataSize
  10884. + error:nil]);
  10885. // State after load: [0.0, 0.0, 1.0, 2.0, 3.0]
  10886. - TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer;
  10887. + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer;
  10888. XCTAssertNotNil(outBuffer);
  10889. XCTAssertEqual(outBuffer.size, bufferSize);
  10890. @@ -80,23 +92,32 @@ NS_ASSUME_NONNULL_BEGIN
  10891. NSInteger initialDataSize = 4;
  10892. float initialArray[] = {1.0f, 2.0f, 3.0f, 4.0f};
  10893. - TFLFloatBuffer *initialBuffer =
  10894. - [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize];
  10895. + TFLFloatBuffer* initialBuffer =
  10896. + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0])
  10897. + size:initialDataSize];
  10898. NSInteger bufferSize = 5;
  10899. - TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
  10900. + TFLRingBuffer* ringBuffer =
  10901. + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
  10902. - XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]);
  10903. + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer
  10904. + offset:0
  10905. + size:initialDataSize
  10906. + error:nil]);
  10907. // State after load: [0.0, 1.0, 2.0, 3.0, 4.0]
  10908. NSInteger inDataSize = 3;
  10909. float inArray[] = {5, 6, 7};
  10910. - TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:inDataSize];
  10911. + TFLFloatBuffer* inBuffer =
  10912. + [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:inDataSize];
  10913. - XCTAssertTrue([ringBuffer loadBuffer:inBuffer offset:0 size:inDataSize error:nil]);
  10914. + XCTAssertTrue([ringBuffer loadBuffer:inBuffer
  10915. + offset:0
  10916. + size:inDataSize
  10917. + error:nil]);
  10918. - TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer;
  10919. + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer;
  10920. XCTAssertNotNil(outBuffer);
  10921. XCTAssertEqual(outBuffer.size, bufferSize);
  10922. @@ -112,24 +133,33 @@ NS_ASSUME_NONNULL_BEGIN
  10923. NSInteger initialDataSize = 5;
  10924. float initialArray[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
  10925. - TFLFloatBuffer *initialBuffer =
  10926. - [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize];
  10927. + TFLFloatBuffer* initialBuffer =
  10928. + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0])
  10929. + size:initialDataSize];
  10930. NSInteger bufferSize = 5;
  10931. - TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
  10932. + TFLRingBuffer* ringBuffer =
  10933. + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
  10934. - XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]);
  10935. + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer
  10936. + offset:0
  10937. + size:initialDataSize
  10938. + error:nil]);
  10939. // State after load: [1.0, 2.0, 3.0, 4.0, 5.0]
  10940. NSInteger sourceDataSize = 6;
  10941. float sourceArray[] = {6, 7, 8, 9, 10, 11};
  10942. - TFLFloatBuffer *sourceBuffer =
  10943. - [[TFLFloatBuffer alloc] initWithData:&(sourceArray[0]) size:sourceDataSize];
  10944. + TFLFloatBuffer* sourceBuffer =
  10945. + [[TFLFloatBuffer alloc] initWithData:&(sourceArray[0])
  10946. + size:sourceDataSize];
  10947. - XCTAssertTrue([ringBuffer loadBuffer:sourceBuffer offset:0 size:sourceDataSize error:nil]);
  10948. + XCTAssertTrue([ringBuffer loadBuffer:sourceBuffer
  10949. + offset:0
  10950. + size:sourceDataSize
  10951. + error:nil]);
  10952. - TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer;
  10953. + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer;
  10954. XCTAssertNotNil(outBuffer);
  10955. XCTAssertEqual(outBuffer.size, bufferSize);
  10956. @@ -145,25 +175,34 @@ NS_ASSUME_NONNULL_BEGIN
  10957. NSInteger initialDataSize = 5;
  10958. float initialArray[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
  10959. - TFLFloatBuffer *initialBuffer =
  10960. - [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize];
  10961. + TFLFloatBuffer* initialBuffer =
  10962. + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0])
  10963. + size:initialDataSize];
  10964. NSInteger bufferSize = 5;
  10965. - TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
  10966. + TFLRingBuffer* ringBuffer =
  10967. + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
  10968. - XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]);
  10969. + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer
  10970. + offset:0
  10971. + size:initialDataSize
  10972. + error:nil]);
  10973. // State after load: [1.0, 2.0, 3.0, 4.0, 5.0]
  10974. NSInteger totalInSize = 8;
  10975. float inArray[] = {6, 7, 8, 9, 10, 11, 12, 13};
  10976. - TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize];
  10977. + TFLFloatBuffer* inBuffer =
  10978. + [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize];
  10979. NSInteger offset = 2;
  10980. NSInteger inDataSize = 6;
  10981. - XCTAssertTrue([ringBuffer loadBuffer:inBuffer offset:offset size:inDataSize error:nil]);
  10982. + XCTAssertTrue([ringBuffer loadBuffer:inBuffer
  10983. + offset:offset
  10984. + size:inDataSize
  10985. + error:nil]);
  10986. - TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer;
  10987. + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer;
  10988. XCTAssertNotNil(outBuffer);
  10989. XCTAssertEqual(outBuffer.size, bufferSize);
  10990. @@ -179,25 +218,34 @@ NS_ASSUME_NONNULL_BEGIN
  10991. NSInteger initialDataSize = 2;
  10992. float initialArray[] = {1.0f, 2.0f};
  10993. - TFLFloatBuffer *initialBuffer =
  10994. - [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize];
  10995. + TFLFloatBuffer* initialBuffer =
  10996. + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0])
  10997. + size:initialDataSize];
  10998. NSInteger bufferSize = 5;
  10999. - TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
  11000. + TFLRingBuffer* ringBuffer =
  11001. + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
  11002. - XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]);
  11003. + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer
  11004. + offset:0
  11005. + size:initialDataSize
  11006. + error:nil]);
  11007. // State after load: [0.0, 0.0, 0.0, 1.0, 2.0]
  11008. NSInteger totalInSize = 4;
  11009. float inArray[] = {6.0f, 7.0f, 8.0f, 9.0f};
  11010. - TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize];
  11011. + TFLFloatBuffer* inBuffer =
  11012. + [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize];
  11013. NSInteger offset = 2;
  11014. NSInteger inDataSize = 2;
  11015. - XCTAssertTrue([ringBuffer loadBuffer:inBuffer offset:offset size:inDataSize error:nil]);
  11016. + XCTAssertTrue([ringBuffer loadBuffer:inBuffer
  11017. + offset:offset
  11018. + size:inDataSize
  11019. + error:nil]);
  11020. - TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer;
  11021. + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer;
  11022. XCTAssertNotNil(outBuffer);
  11023. XCTAssertEqual(outBuffer.size, bufferSize);
  11024. @@ -213,26 +261,36 @@ NS_ASSUME_NONNULL_BEGIN
  11025. NSInteger initialDataSize = 2;
  11026. float initialArray[] = {1.0f, 2.0f};
  11027. - TFLFloatBuffer *initialBuffer =
  11028. - [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize];
  11029. + TFLFloatBuffer* initialBuffer =
  11030. + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0])
  11031. + size:initialDataSize];
  11032. NSInteger bufferSize = 5;
  11033. - TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
  11034. + TFLRingBuffer* ringBuffer =
  11035. + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
  11036. - XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]);
  11037. + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer
  11038. + offset:0
  11039. + size:initialDataSize
  11040. + error:nil]);
  11041. NSInteger totalInSize = 4;
  11042. float inArray[] = {6.0f, 7.0f, 8.0f, 9.0f};
  11043. - TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize];
  11044. + TFLFloatBuffer* inBuffer =
  11045. + [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize];
  11046. NSInteger offset = 2;
  11047. NSInteger inDataSize = 3;
  11048. - NSError *error = nil;
  11049. - XCTAssertFalse([ringBuffer loadBuffer:inBuffer offset:offset size:inDataSize error:&error]);
  11050. + NSError* error = nil;
  11051. + XCTAssertFalse([ringBuffer loadBuffer:inBuffer
  11052. + offset:offset
  11053. + size:inDataSize
  11054. + error:&error]);
  11055. XCTAssertNotNil(error);
  11056. - VerifyError(error, @"org.tensorflow.lite.tasks", TFLSupportErrorCodeInvalidArgumentError,
  11057. + VerifyError(error, @"org.tensorflow.lite.tasks",
  11058. + TFLSupportErrorCodeInvalidArgumentError,
  11059. @"offset + size exceeds the maximum size of the source buffer.");
  11060. }
  11061. @@ -240,19 +298,24 @@ NS_ASSUME_NONNULL_BEGIN
  11062. NSInteger initialDataSize = 2;
  11063. float initialArray[] = {1.0f, 2.0f};
  11064. - TFLFloatBuffer *initialBuffer =
  11065. - [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize];
  11066. + TFLFloatBuffer* initialBuffer =
  11067. + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0])
  11068. + size:initialDataSize];
  11069. NSInteger bufferSize = 5;
  11070. - TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
  11071. + TFLRingBuffer* ringBuffer =
  11072. + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
  11073. - XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]);
  11074. + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer
  11075. + offset:0
  11076. + size:initialDataSize
  11077. + error:nil]);
  11078. [ringBuffer clear];
  11079. float expectedData[] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
  11080. - TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer;
  11081. + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer;
  11082. XCTAssertNotNil(outBuffer);
  11083. XCTAssertEqual(outBuffer.size, bufferSize);
  11084. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m
  11085. index d03b6044bdd68..b1ed8cf1e2f6d 100644
  11086. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m
  11087. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m
  11088. @@ -29,8 +29,9 @@ NS_ASSUME_NONNULL_BEGIN
  11089. // Put setup code here. This method is called before the invocation of each test method in the
  11090. // class.
  11091. [super setUp];
  11092. - self.modelPath = [[NSBundle bundleForClass:self.class] pathForResource:@"mobilenet_v2_1.0_224"
  11093. - ofType:@"tflite"];
  11094. + self.modelPath = [[NSBundle bundleForClass:self.class]
  11095. + pathForResource:@"mobilenet_v2_1.0_224"
  11096. + ofType:@"tflite"];
  11097. XCTAssertNotNil(self.modelPath);
  11098. }
  11099. @@ -42,8 +43,9 @@ NS_ASSUME_NONNULL_BEGIN
  11100. [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:nil];
  11101. XCTAssertNotNil(imageClassifier);
  11102. - GMLImage *gmlImage =
  11103. - [GMLImage imageFromBundleWithClass:self.class fileName:@"burger" ofType:@"jpg"];
  11104. + GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:self.class
  11105. + fileName:@"burger"
  11106. + ofType:@"jpg"];
  11107. XCTAssertNotNil(gmlImage);
  11108. TFLClassificationResult *classificationResults = [imageClassifier classifyWithGMLImage:gmlImage
  11109. @@ -67,14 +69,16 @@ NS_ASSUME_NONNULL_BEGIN
  11110. [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:nil];
  11111. XCTAssertNotNil(imageClassifier);
  11112. - GMLImage *gmlImage =
  11113. - [GMLImage imageFromBundleWithClass:self.class fileName:@"burger" ofType:@"jpg"];
  11114. + GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:self.class
  11115. + fileName:@"burger"
  11116. + ofType:@"jpg"];
  11117. XCTAssertNotNil(gmlImage);
  11118. TFLClassificationResult *classificationResults = [imageClassifier classifyWithGMLImage:gmlImage
  11119. error:nil];
  11120. XCTAssertTrue(classificationResults.classifications.count > 0);
  11121. - XCTAssertLessThanOrEqual(classificationResults.classifications[0].categories.count, maxResults);
  11122. + XCTAssertLessThanOrEqual(
  11123. + classificationResults.classifications[0].categories.count, maxResults);
  11124. TFLCategory *category = classificationResults.classifications[0].categories[0];
  11125. XCTAssertTrue([category.label isEqual:@"cheeseburger"]);
  11126. @@ -92,8 +96,9 @@ NS_ASSUME_NONNULL_BEGIN
  11127. [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:nil];
  11128. XCTAssertNotNil(imageClassifier);
  11129. - GMLImage *gmlImage =
  11130. - [GMLImage imageFromBundleWithClass:self.class fileName:@"multi_objects" ofType:@"jpg"];
  11131. + GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:self.class
  11132. + fileName:@"multi_objects"
  11133. + ofType:@"jpg"];
  11134. XCTAssertNotNil(gmlImage);
  11135. CGRect roi = CGRectMake(406, 110, 148, 153);
  11136. @@ -117,8 +122,9 @@ NS_ASSUME_NONNULL_BEGIN
  11137. [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:nil];
  11138. XCTAssertNotNil(imageClassifier);
  11139. - GMLImage *gmlImage =
  11140. - [GMLImage imageFromBundleWithClass:self.class fileName:@"sparrow" ofType:@"png"];
  11141. + GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:self.class
  11142. + fileName:@"sparrow"
  11143. + ofType:@"png"];
  11144. XCTAssertNotNil(gmlImage);
  11145. TFLClassificationResult *classificationResults = [imageClassifier classifyWithGMLImage:gmlImage
  11146. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m
  11147. index c2977475f6d4f..f483a516b9bc6 100644
  11148. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m
  11149. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m
  11150. @@ -18,10 +18,11 @@
  11151. #import "tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h"
  11152. #import "tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h"
  11153. -#define VerifyColoredLabel(coloredLabel, expectedR, expectedG, expectedB, expectedLabel) \
  11154. - XCTAssertEqual(coloredLabel.r, expectedR); \
  11155. - XCTAssertEqual(coloredLabel.g, expectedG); \
  11156. - XCTAssertEqual(coloredLabel.b, expectedB); \
  11157. +#define VerifyColoredLabel(coloredLabel, expectedR, expectedG, expectedB, \
  11158. + expectedLabel) \
  11159. + XCTAssertEqual(coloredLabel.r, expectedR); \
  11160. + XCTAssertEqual(coloredLabel.g, expectedG); \
  11161. + XCTAssertEqual(coloredLabel.b, expectedB); \
  11162. XCTAssertEqualObjects(coloredLabel.label, expectedLabel)
  11163. // The maximum fraction of pixels in the candidate mask that can have a
  11164. @@ -40,22 +41,24 @@ NSInteger const deepLabV3SegmentationHeight = 257;
  11165. @interface TFLImageSegmenterTests : XCTestCase
  11166. -@property(nonatomic, nullable) NSString *modelPath;
  11167. +@property(nonatomic, nullable) NSString* modelPath;
  11168. @end
  11169. @implementation TFLImageSegmenterTests
  11170. - (void)setUp {
  11171. - // Put setup code here. This method is called before the invocation of each test method in the
  11172. - // class.
  11173. + // Put setup code here. This method is called before the invocation of each
  11174. + // test method in the class.
  11175. [super setUp];
  11176. - self.modelPath = [[NSBundle bundleForClass:self.class] pathForResource:@"deeplabv3"
  11177. - ofType:@"tflite"];
  11178. + self.modelPath =
  11179. + [[NSBundle bundleForClass:self.class] pathForResource:@"deeplabv3"
  11180. + ofType:@"tflite"];
  11181. XCTAssertNotNil(self.modelPath);
  11182. }
  11183. -- (void)compareWithDeepLabV3PartialColoredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels {
  11184. +- (void)compareWithDeepLabV3PartialColoredLabels:
  11185. + (NSArray<TFLColoredLabel*>*)coloredLabels {
  11186. VerifyColoredLabel(coloredLabels[0],
  11187. 0, // expectedR
  11188. 0, // expectedG
  11189. @@ -204,58 +207,67 @@ NSInteger const deepLabV3SegmentationHeight = 257;
  11190. }
  11191. - (void)testSuccessfulImageSegmentationWithCategoryMask {
  11192. - TFLImageSegmenterOptions *imageSegmenterOptions =
  11193. + TFLImageSegmenterOptions* imageSegmenterOptions =
  11194. [[TFLImageSegmenterOptions alloc] initWithModelPath:self.modelPath];
  11195. - TFLImageSegmenter *imageSegmenter =
  11196. - [TFLImageSegmenter imageSegmenterWithOptions:imageSegmenterOptions error:nil];
  11197. + TFLImageSegmenter* imageSegmenter =
  11198. + [TFLImageSegmenter imageSegmenterWithOptions:imageSegmenterOptions
  11199. + error:nil];
  11200. XCTAssertNotNil(imageSegmenter);
  11201. - GMLImage *gmlImage = [GMLImage imageFromBundleWithClass:self.class
  11202. - fileName:@"segmentation_input_rotation0"
  11203. - ofType:@"jpg"];
  11204. + GMLImage* gmlImage =
  11205. + [GMLImage imageFromBundleWithClass:self.class
  11206. + fileName:@"segmentation_input_rotation0"
  11207. + ofType:@"jpg"];
  11208. XCTAssertNotNil(gmlImage);
  11209. - TFLSegmentationResult *segmentationResult = [imageSegmenter segmentWithGMLImage:gmlImage
  11210. - error:nil];
  11211. + TFLSegmentationResult* segmentationResult =
  11212. + [imageSegmenter segmentWithGMLImage:gmlImage error:nil];
  11213. XCTAssertNotNil(segmentationResult);
  11214. XCTAssertEqual(segmentationResult.segmentations.count, 1);
  11215. XCTAssertNotNil(segmentationResult.segmentations[0].coloredLabels);
  11216. - [self compareWithDeepLabV3PartialColoredLabels:segmentationResult.segmentations[0].coloredLabels];
  11217. + [self compareWithDeepLabV3PartialColoredLabels:segmentationResult
  11218. + .segmentations[0]
  11219. + .coloredLabels];
  11220. XCTAssertNotNil(segmentationResult.segmentations[0].categoryMask);
  11221. XCTAssertTrue(segmentationResult.segmentations[0].categoryMask.mask != nil);
  11222. - GMLImage *goldenImage = [GMLImage imageFromBundleWithClass:self.class
  11223. - fileName:@"segmentation_golden_rotation0"
  11224. - ofType:@"png"];
  11225. + GMLImage* goldenImage =
  11226. + [GMLImage imageFromBundleWithClass:self.class
  11227. + fileName:@"segmentation_golden_rotation0"
  11228. + ofType:@"png"];
  11229. XCTAssertNotNil(goldenImage);
  11230. CVPixelBufferRef pixelBuffer = [goldenImage grayScalePixelBuffer];
  11231. CVPixelBufferLockBaseAddress(pixelBuffer, kCVPixelBufferLock_ReadOnly);
  11232. - UInt8 *pixelBufferBaseAddress = (UInt8 *)CVPixelBufferGetBaseAddress(pixelBuffer);
  11233. + UInt8* pixelBufferBaseAddress =
  11234. + (UInt8*)CVPixelBufferGetBaseAddress(pixelBuffer);
  11235. XCTAssertEqual(deepLabV3SegmentationWidth,
  11236. segmentationResult.segmentations[0].categoryMask.width);
  11237. XCTAssertEqual(deepLabV3SegmentationHeight,
  11238. segmentationResult.segmentations[0].categoryMask.height);
  11239. - NSInteger numPixels = deepLabV3SegmentationWidth * deepLabV3SegmentationHeight;
  11240. + NSInteger numPixels =
  11241. + deepLabV3SegmentationWidth * deepLabV3SegmentationHeight;
  11242. float inconsistentPixels = 0;
  11243. for (int i = 0; i < numPixels; i++)
  11244. - if (segmentationResult.segmentations[0].categoryMask.mask[i] * kGoldenMaskMagnificationFactor !=
  11245. + if (segmentationResult.segmentations[0].categoryMask.mask[i] *
  11246. + kGoldenMaskMagnificationFactor !=
  11247. pixelBufferBaseAddress[i])
  11248. inconsistentPixels += 1;
  11249. CVPixelBufferUnlockBaseAddress(pixelBuffer, kCVPixelBufferLock_ReadOnly);
  11250. - XCTAssertLessThan(inconsistentPixels / (float)numPixels, kGoldenMaskTolerance);
  11251. + XCTAssertLessThan(inconsistentPixels / (float)numPixels,
  11252. + kGoldenMaskTolerance);
  11253. }
  11254. @end
  11255. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.m
  11256. index f6820f335e18b..f7091a5995b02 100644
  11257. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.m
  11258. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.m
  11259. @@ -18,16 +18,22 @@
  11260. #import "tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h"
  11261. #import "tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h"
  11262. -#define VerifyDetection(detection, expectedBoundingBox, expectedFirstScore, expectedFirstLabel) \
  11263. - XCTAssertGreaterThan(detection.categories.count, 0); \
  11264. - NSLog(@"Detected %f", detection.categories[0].score); \
  11265. - NSLog(@"Expected %f", expectedFirstScore); \
  11266. - XCTAssertEqual(detection.boundingBox.origin.x, expectedBoundingBox.origin.x); \
  11267. - XCTAssertEqual(detection.boundingBox.origin.y, expectedBoundingBox.origin.y); \
  11268. - XCTAssertEqual(detection.boundingBox.size.width, expectedBoundingBox.size.width); \
  11269. - XCTAssertEqual(detection.boundingBox.size.height, expectedBoundingBox.size.height); \
  11270. - XCTAssertEqualObjects(detection.categories[0].label, expectedFirstLabel); \
  11271. - XCTAssertEqualWithAccuracy(detection.categories[0].score, expectedFirstScore, 0.001)
  11272. +#define VerifyDetection(detection, expectedBoundingBox, expectedFirstScore, \
  11273. + expectedFirstLabel) \
  11274. + XCTAssertGreaterThan(detection.categories.count, 0); \
  11275. + NSLog(@"Detected %f", detection.categories[0].score); \
  11276. + NSLog(@"Expected %f", expectedFirstScore); \
  11277. + XCTAssertEqual(detection.boundingBox.origin.x, \
  11278. + expectedBoundingBox.origin.x); \
  11279. + XCTAssertEqual(detection.boundingBox.origin.y, \
  11280. + expectedBoundingBox.origin.y); \
  11281. + XCTAssertEqual(detection.boundingBox.size.width, \
  11282. + expectedBoundingBox.size.width); \
  11283. + XCTAssertEqual(detection.boundingBox.size.height, \
  11284. + expectedBoundingBox.size.height); \
  11285. + XCTAssertEqualObjects(detection.categories[0].label, expectedFirstLabel); \
  11286. + XCTAssertEqualWithAccuracy(detection.categories[0].score, \
  11287. + expectedFirstScore, 0.001)
  11288. @interface TFLObjectDetectorTests : XCTestCase
  11289. @property(nonatomic, nullable) NSString *modelPath;
  11290. @@ -77,8 +83,9 @@
  11291. [TFLObjectDetector objectDetectorWithOptions:objectDetectorOptions error:nil];
  11292. XCTAssertNotNil(objectDetector);
  11293. - GMLImage *gmlImage =
  11294. - [GMLImage imageFromBundleWithClass:self.class fileName:@"cats_and_dogs" ofType:@"jpg"];
  11295. + GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:self.class
  11296. + fileName:@"cats_and_dogs"
  11297. + ofType:@"jpg"];
  11298. XCTAssertNotNil(gmlImage);
  11299. TFLDetectionResult *detectionResults = [objectDetector detectWithGMLImage:gmlImage error:nil];
  11300. @@ -95,8 +102,9 @@
  11301. [TFLObjectDetector objectDetectorWithOptions:objectDetectorOptions error:nil];
  11302. XCTAssertNotNil(objectDetector);
  11303. - GMLImage *gmlImage =
  11304. - [GMLImage imageFromBundleWithClass:self.class fileName:@"cats_and_dogs" ofType:@"jpg"];
  11305. + GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:self.class
  11306. + fileName:@"cats_and_dogs"
  11307. + ofType:@"jpg"];
  11308. XCTAssertNotNil(gmlImage);
  11309. TFLDetectionResult *detectionResult = [objectDetector detectWithGMLImage:gmlImage error:nil];
  11310. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h
  11311. index ed679c22a467b..c10c82afc1913 100644
  11312. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h
  11313. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h
  11314. @@ -28,11 +28,13 @@ NS_ASSUME_NONNULL_BEGIN
  11315. /**
  11316. * Initializes the tokenizer with the path to wordpiece vocabulary file.
  11317. */
  11318. -- (instancetype)initWithVocabPath:(NSString *)vocabPath NS_DESIGNATED_INITIALIZER;
  11319. +- (instancetype)initWithVocabPath:(NSString*)vocabPath
  11320. + NS_DESIGNATED_INITIALIZER;
  11321. /**
  11322. * Initializes the tokenizer with a list of tokens.
  11323. */
  11324. -- (instancetype)initWithVocab:(NSArray<NSString *> *)vocab NS_DESIGNATED_INITIALIZER;
  11325. +- (instancetype)initWithVocab:(NSArray<NSString*>*)vocab
  11326. + NS_DESIGNATED_INITIALIZER;
  11327. @end
  11328. NS_ASSUME_NONNULL_END
  11329. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h
  11330. index f556dc642d736..be4010abd8e6f 100644
  11331. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h
  11332. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h
  11333. @@ -28,6 +28,6 @@ NS_ASSUME_NONNULL_BEGIN
  11334. /**
  11335. * Initializes the tokenizer with the path to sentencepiece model file.
  11336. */
  11337. -- (instancetype)initWithModelPath:(NSString *)modelPath;
  11338. +- (instancetype)initWithModelPath:(NSString*)modelPath;
  11339. @end
  11340. NS_ASSUME_NONNULL_END
  11341. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h
  11342. index ee0972f8aba30..bd832060b6e80 100644
  11343. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h
  11344. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h
  11345. @@ -26,7 +26,7 @@ NS_ASSUME_NONNULL_BEGIN
  11346. *
  11347. * @return A list of tokens.
  11348. */
  11349. -- (NSArray<NSString *> *)tokensFromInput:(NSString *)input;
  11350. +- (NSArray<NSString*>*)tokensFromInput:(NSString*)input;
  11351. /*
  11352. * Convert a list of tokens back to their coressponding IDs.
  11353. @@ -34,6 +34,6 @@ NS_ASSUME_NONNULL_BEGIN
  11354. *
  11355. * @return A list of ids.
  11356. */
  11357. -- (NSArray<NSNumber *> *)idsFromTokens:(NSArray<NSString *> *)tokens;
  11358. +- (NSArray<NSNumber*>*)idsFromTokens:(NSArray<NSString*>*)tokens;
  11359. @end
  11360. NS_ASSUME_NONNULL_END
  11361. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h
  11362. index 574b555301616..14e2906675b71 100644
  11363. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h
  11364. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h
  11365. @@ -18,21 +18,24 @@ limitations under the License.
  11366. using ::tflite::support::text::tokenizer::Tokenizer;
  11367. /**
  11368. - * Invokes the cpp tokenizer's tokenize function and converts input/output to objc.
  11369. + * Invokes the cpp tokenizer's tokenize function and converts input/output to
  11370. + * objc.
  11371. *
  11372. * @param tokenizer The cpp tokenizer pointer.
  11373. * @param input The input string to be tokenized.
  11374. *
  11375. * @return A list of tokens.
  11376. */
  11377. -NSArray<NSString *> *Tokenize(Tokenizer *tokenizer, NSString *input);
  11378. +NSArray<NSString*>* Tokenize(Tokenizer* tokenizer, NSString* input);
  11379. /**
  11380. - * Invokes the cpp tokenizer's convertTokensToIds function and converts input/output to objc.
  11381. + * Invokes the cpp tokenizer's convertTokensToIds function and converts
  11382. + * input/output to objc.
  11383. *
  11384. * @param tokenizer The cpp tokenizer pointer.
  11385. * @param input The tokens to be converted.
  11386. *
  11387. * @return A list of ids.
  11388. */
  11389. -NSArray<NSNumber *> *ConvertTokensToIds(Tokenizer *tokenizer, NSArray<NSString *> *tokens);
  11390. +NSArray<NSNumber*>* ConvertTokensToIds(Tokenizer* tokenizer,
  11391. + NSArray<NSString*>* tokens);
  11392. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm b/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm
  11393. index 6e9cf23802427..2a11bb6730474 100644
  11394. --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm
  11395. +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm
  11396. @@ -14,10 +14,13 @@ limitations under the License.
  11397. ==============================================================================*/
  11398. #import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h"
  11399. -std::string MakeString(NSString* str) { return std::string([str UTF8String]); }
  11400. +std::string MakeString(NSString* str) {
  11401. + return std::string([str UTF8String]);
  11402. +}
  11403. NSString* MakeNSString(const std::string& str) {
  11404. - return [[NSString alloc] initWithBytes:const_cast<void*>(static_cast<const void*>(str.data()))
  11405. - length:str.length()
  11406. - encoding:NSUTF8StringEncoding];
  11407. + return [[NSString alloc]
  11408. + initWithBytes:const_cast<void*>(static_cast<const void*>(str.data()))
  11409. + length:str.length()
  11410. + encoding:NSUTF8StringEncoding];
  11411. }
  11412. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java
  11413. index 2b59c675b0316..6f2f2d437fb4a 100644
  11414. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java
  11415. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java
  11416. @@ -15,19 +15,24 @@ limitations under the License.
  11417. package org.tensorflow.lite.support.audio;
  11418. -import static java.lang.System.arraycopy;
  11419. import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument;
  11420. +import static java.lang.System.arraycopy;
  11421. +
  11422. import android.media.AudioFormat;
  11423. import android.media.AudioRecord;
  11424. import android.os.Build;
  11425. +
  11426. import androidx.annotation.RequiresApi;
  11427. +
  11428. import com.google.auto.value.AutoValue;
  11429. +
  11430. +import org.tensorflow.lite.DataType;
  11431. +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  11432. +
  11433. import java.nio.ByteBuffer;
  11434. import java.nio.ByteOrder;
  11435. import java.nio.FloatBuffer;
  11436. -import org.tensorflow.lite.DataType;
  11437. -import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  11438. /**
  11439. * Defines a ring buffer and some utility functions to prepare the input audio samples.
  11440. @@ -60,285 +65,282 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  11441. * </pre>
  11442. */
  11443. public class TensorAudio {
  11444. + private static final String TAG = TensorAudio.class.getSimpleName();
  11445. + private final FloatRingBuffer buffer;
  11446. + private final TensorAudioFormat format;
  11447. - private static final String TAG = TensorAudio.class.getSimpleName();
  11448. - private final FloatRingBuffer buffer;
  11449. - private final TensorAudioFormat format;
  11450. -
  11451. - /**
  11452. - * Creates a {@link android.media.AudioRecord} instance with a ring buffer whose size is {@code
  11453. - * sampleCounts} * {@code format.getChannels()}.
  11454. - *
  11455. - * @param format the expected {@link TensorAudioFormat} of audio data loaded into this class.
  11456. - * @param sampleCounts the number of samples to be fed into the model
  11457. - */
  11458. - public static TensorAudio create(TensorAudioFormat format, int sampleCounts) {
  11459. - return new TensorAudio(format, sampleCounts);
  11460. - }
  11461. -
  11462. - /**
  11463. - * Creates a {@link TensorAudio} instance with a ring buffer whose size is {@code sampleCounts} *
  11464. - * {@code format.getChannelCount()}.
  11465. - *
  11466. - * @param format the {@link android.media.AudioFormat} required by the TFLite model. It defines
  11467. - * the number of channels and sample rate.
  11468. - * @param sampleCounts the number of samples to be fed into the model
  11469. - */
  11470. - public static TensorAudio create(AudioFormat format, int sampleCounts) {
  11471. - return new TensorAudio(TensorAudioFormat.create(format), sampleCounts);
  11472. - }
  11473. -
  11474. - /**
  11475. - * Wraps a few constants describing the format of the incoming audio samples, namely number of
  11476. - * channels and the sample rate. By default, channels is set to 1.
  11477. - */
  11478. - @AutoValue
  11479. - public abstract static class TensorAudioFormat {
  11480. - private static final int DEFAULT_CHANNELS = 1;
  11481. -
  11482. - /** Creates a {@link TensorAudioFormat} instance from Android AudioFormat class. */
  11483. - @RequiresApi(Build.VERSION_CODES.M)
  11484. - public static TensorAudioFormat create(AudioFormat format) {
  11485. - return TensorAudioFormat.builder()
  11486. - .setChannels(format.getChannelCount())
  11487. - .setSampleRate(format.getSampleRate())
  11488. - .build();
  11489. + /**
  11490. + * Creates a {@link android.media.AudioRecord} instance with a ring buffer whose size is {@code
  11491. + * sampleCounts} * {@code format.getChannels()}.
  11492. + *
  11493. + * @param format the expected {@link TensorAudioFormat} of audio data loaded into this class.
  11494. + * @param sampleCounts the number of samples to be fed into the model
  11495. + */
  11496. + public static TensorAudio create(TensorAudioFormat format, int sampleCounts) {
  11497. + return new TensorAudio(format, sampleCounts);
  11498. }
  11499. - public abstract int getChannels();
  11500. -
  11501. - public abstract int getSampleRate();
  11502. -
  11503. - public static Builder builder() {
  11504. - return new AutoValue_TensorAudio_TensorAudioFormat.Builder().setChannels(DEFAULT_CHANNELS);
  11505. + /**
  11506. + * Creates a {@link TensorAudio} instance with a ring buffer whose size is {@code sampleCounts}
  11507. + * *
  11508. + * {@code format.getChannelCount()}.
  11509. + *
  11510. + * @param format the {@link android.media.AudioFormat} required by the TFLite model. It defines
  11511. + * the number of channels and sample rate.
  11512. + * @param sampleCounts the number of samples to be fed into the model
  11513. + */
  11514. + public static TensorAudio create(AudioFormat format, int sampleCounts) {
  11515. + return new TensorAudio(TensorAudioFormat.create(format), sampleCounts);
  11516. }
  11517. - /** Builder for {@link TensorAudioFormat} */
  11518. - @AutoValue.Builder
  11519. - public abstract static class Builder {
  11520. -
  11521. - /* By default, it's set to have 1 channel. */
  11522. - public abstract Builder setChannels(int value);
  11523. -
  11524. - public abstract Builder setSampleRate(int value);
  11525. -
  11526. - abstract TensorAudioFormat autoBuild();
  11527. -
  11528. - public TensorAudioFormat build() {
  11529. - TensorAudioFormat format = autoBuild();
  11530. - checkArgument(format.getChannels() > 0, "Number of channels should be greater than 0");
  11531. - checkArgument(format.getSampleRate() > 0, "Sample rate should be greater than 0");
  11532. - return format;
  11533. - }
  11534. - }
  11535. - }
  11536. -
  11537. - /**
  11538. - * Stores the input audio samples {@code src} in the ring buffer.
  11539. - *
  11540. - * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
  11541. - * multi-channel input, the array is interleaved.
  11542. - */
  11543. - public void load(float[] src) {
  11544. - load(src, 0, src.length);
  11545. - }
  11546. -
  11547. - /**
  11548. - * Stores the input audio samples {@code src} in the ring buffer.
  11549. - *
  11550. - * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
  11551. - * multi-channel input, the array is interleaved.
  11552. - * @param offsetInFloat starting position in the {@code src} array
  11553. - * @param sizeInFloat the number of float values to be copied
  11554. - * @throws IllegalArgumentException for incompatible audio format or incorrect input size
  11555. - */
  11556. - public void load(float[] src, int offsetInFloat, int sizeInFloat) {
  11557. - checkArgument(
  11558. - sizeInFloat % format.getChannels() == 0,
  11559. - String.format(
  11560. - "Size (%d) needs to be a multiplier of the number of channels (%d)",
  11561. - sizeInFloat, format.getChannels()));
  11562. - buffer.load(src, offsetInFloat, sizeInFloat);
  11563. - }
  11564. -
  11565. - /**
  11566. - * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ring
  11567. - * buffer.
  11568. - *
  11569. - * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
  11570. - * multi-channel input, the array is interleaved.
  11571. - */
  11572. - public void load(short[] src) {
  11573. - load(src, 0, src.length);
  11574. - }
  11575. -
  11576. - /**
  11577. - * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ring
  11578. - * buffer.
  11579. - *
  11580. - * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
  11581. - * multi-channel input, the array is interleaved.
  11582. - * @param offsetInShort starting position in the src array
  11583. - * @param sizeInShort the number of short values to be copied
  11584. - * @throws IllegalArgumentException if the source array can't be copied
  11585. - */
  11586. - public void load(short[] src, int offsetInShort, int sizeInShort) {
  11587. - checkArgument(
  11588. - offsetInShort + sizeInShort <= src.length,
  11589. - String.format(
  11590. - "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
  11591. - offsetInShort, sizeInShort, src.length));
  11592. - float[] floatData = new float[sizeInShort];
  11593. - for (int i = 0; i < sizeInShort; i++) {
  11594. - // Convert the data to PCM Float encoding i.e. values between -1 and 1
  11595. - floatData[i] = src[i + offsetInShort] * 1.f / Short.MAX_VALUE;
  11596. - }
  11597. - load(floatData);
  11598. - }
  11599. -
  11600. - /**
  11601. - * Loads latest data from the {@link android.media.AudioRecord} in a non-blocking way. Only
  11602. - * supporting ENCODING_PCM_16BIT and ENCODING_PCM_FLOAT.
  11603. - *
  11604. - * @param record an instance of {@link android.media.AudioRecord}
  11605. - * @return number of captured audio values whose size is {@code channelCount * sampleCount}. If
  11606. - * there was no new data in the AudioRecord or an error occurred, this method will return 0.
  11607. - * @throws IllegalArgumentException for unsupported audio encoding format
  11608. - * @throws IllegalStateException if reading from AudioRecord failed
  11609. - */
  11610. - @RequiresApi(Build.VERSION_CODES.M)
  11611. - public int load(AudioRecord record) {
  11612. - checkArgument(
  11613. - this.format.equals(TensorAudioFormat.create(record.getFormat())),
  11614. - "Incompatible audio format.");
  11615. - int loadedValues = 0;
  11616. - if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_FLOAT) {
  11617. - float[] newData = new float[record.getChannelCount() * record.getBufferSizeInFrames()];
  11618. - loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
  11619. - if (loadedValues > 0) {
  11620. - load(newData, 0, loadedValues);
  11621. - return loadedValues;
  11622. - }
  11623. - } else if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_16BIT) {
  11624. - short[] newData = new short[record.getChannelCount() * record.getBufferSizeInFrames()];
  11625. - loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
  11626. - if (loadedValues > 0) {
  11627. - load(newData, 0, loadedValues);
  11628. - return loadedValues;
  11629. - }
  11630. - } else {
  11631. - throw new IllegalArgumentException(
  11632. - "Unsupported encoding. Requires ENCODING_PCM_16BIT or ENCODING_PCM_FLOAT.");
  11633. + /**
  11634. + * Wraps a few constants describing the format of the incoming audio samples, namely number of
  11635. + * channels and the sample rate. By default, channels is set to 1.
  11636. + */
  11637. + @AutoValue
  11638. + public abstract static class TensorAudioFormat {
  11639. + private static final int DEFAULT_CHANNELS = 1;
  11640. +
  11641. + /** Creates a {@link TensorAudioFormat} instance from Android AudioFormat class. */
  11642. + @RequiresApi(Build.VERSION_CODES.M)
  11643. + public static TensorAudioFormat create(AudioFormat format) {
  11644. + return TensorAudioFormat.builder()
  11645. + .setChannels(format.getChannelCount())
  11646. + .setSampleRate(format.getSampleRate())
  11647. + .build();
  11648. + }
  11649. +
  11650. + public abstract int getChannels();
  11651. +
  11652. + public abstract int getSampleRate();
  11653. +
  11654. + public static Builder builder() {
  11655. + return new AutoValue_TensorAudio_TensorAudioFormat.Builder().setChannels(
  11656. + DEFAULT_CHANNELS);
  11657. + }
  11658. +
  11659. + /** Builder for {@link TensorAudioFormat} */
  11660. + @AutoValue.Builder
  11661. + public abstract static class Builder {
  11662. + /* By default, it's set to have 1 channel. */
  11663. + public abstract Builder setChannels(int value);
  11664. +
  11665. + public abstract Builder setSampleRate(int value);
  11666. +
  11667. + abstract TensorAudioFormat autoBuild();
  11668. +
  11669. + public TensorAudioFormat build() {
  11670. + TensorAudioFormat format = autoBuild();
  11671. + checkArgument(
  11672. + format.getChannels() > 0, "Number of channels should be greater than 0");
  11673. + checkArgument(format.getSampleRate() > 0, "Sample rate should be greater than 0");
  11674. + return format;
  11675. + }
  11676. + }
  11677. }
  11678. - switch (loadedValues) {
  11679. - case AudioRecord.ERROR_INVALID_OPERATION:
  11680. - throw new IllegalStateException("AudioRecord.ERROR_INVALID_OPERATION");
  11681. -
  11682. - case AudioRecord.ERROR_BAD_VALUE:
  11683. - throw new IllegalStateException("AudioRecord.ERROR_BAD_VALUE");
  11684. -
  11685. - case AudioRecord.ERROR_DEAD_OBJECT:
  11686. - throw new IllegalStateException("AudioRecord.ERROR_DEAD_OBJECT");
  11687. + /**
  11688. + * Stores the input audio samples {@code src} in the ring buffer.
  11689. + *
  11690. + * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
  11691. + * multi-channel input, the array is interleaved.
  11692. + */
  11693. + public void load(float[] src) {
  11694. + load(src, 0, src.length);
  11695. + }
  11696. - case AudioRecord.ERROR:
  11697. - throw new IllegalStateException("AudioRecord.ERROR");
  11698. + /**
  11699. + * Stores the input audio samples {@code src} in the ring buffer.
  11700. + *
  11701. + * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
  11702. + * multi-channel input, the array is interleaved.
  11703. + * @param offsetInFloat starting position in the {@code src} array
  11704. + * @param sizeInFloat the number of float values to be copied
  11705. + * @throws IllegalArgumentException for incompatible audio format or incorrect input size
  11706. + */
  11707. + public void load(float[] src, int offsetInFloat, int sizeInFloat) {
  11708. + checkArgument(sizeInFloat % format.getChannels() == 0,
  11709. + String.format("Size (%d) needs to be a multiplier of the number of channels (%d)",
  11710. + sizeInFloat, format.getChannels()));
  11711. + buffer.load(src, offsetInFloat, sizeInFloat);
  11712. + }
  11713. - default:
  11714. - return 0;
  11715. + /**
  11716. + * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the
  11717. + * ring buffer.
  11718. + *
  11719. + * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
  11720. + * multi-channel input, the array is interleaved.
  11721. + */
  11722. + public void load(short[] src) {
  11723. + load(src, 0, src.length);
  11724. }
  11725. - }
  11726. -
  11727. - /**
  11728. - * Returns a float {@link TensorBuffer} holding all the available audio samples in {@link
  11729. - * android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1].
  11730. - */
  11731. - public TensorBuffer getTensorBuffer() {
  11732. - ByteBuffer byteBuffer = buffer.getBuffer();
  11733. - TensorBuffer tensorBuffer =
  11734. - TensorBuffer.createFixedSize(
  11735. - new int[] {
  11736. - /* batch= */ 1, /* modelInputLengthInFloat= */ byteBuffer.asFloatBuffer().limit()
  11737. - },
  11738. - DataType.FLOAT32);
  11739. - tensorBuffer.loadBuffer(byteBuffer);
  11740. - return tensorBuffer;
  11741. - }
  11742. -
  11743. - /* Returns the {@link TensorAudioFormat} associated with the tensor. */
  11744. - public TensorAudioFormat getFormat() {
  11745. - return format;
  11746. - }
  11747. -
  11748. - private TensorAudio(TensorAudioFormat format, int sampleCounts) {
  11749. - this.format = format;
  11750. - this.buffer = new FloatRingBuffer(sampleCounts * format.getChannels());
  11751. - }
  11752. -
  11753. - /** Actual implementation of the ring buffer. */
  11754. - private static class FloatRingBuffer {
  11755. -
  11756. - private final float[] buffer;
  11757. - private int nextIndex = 0;
  11758. -
  11759. - public FloatRingBuffer(int flatSize) {
  11760. - buffer = new float[flatSize];
  11761. +
  11762. + /**
  11763. + * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the
  11764. + * ring buffer.
  11765. + *
  11766. + * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
  11767. + * multi-channel input, the array is interleaved.
  11768. + * @param offsetInShort starting position in the src array
  11769. + * @param sizeInShort the number of short values to be copied
  11770. + * @throws IllegalArgumentException if the source array can't be copied
  11771. + */
  11772. + public void load(short[] src, int offsetInShort, int sizeInShort) {
  11773. + checkArgument(offsetInShort + sizeInShort <= src.length,
  11774. + String.format(
  11775. + "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
  11776. + offsetInShort, sizeInShort, src.length));
  11777. + float[] floatData = new float[sizeInShort];
  11778. + for (int i = 0; i < sizeInShort; i++) {
  11779. + // Convert the data to PCM Float encoding i.e. values between -1 and 1
  11780. + floatData[i] = src[i + offsetInShort] * 1.f / Short.MAX_VALUE;
  11781. + }
  11782. + load(floatData);
  11783. }
  11784. /**
  11785. - * Loads the entire float array to the ring buffer. If the float array is longer than ring
  11786. - * buffer's capacity, samples with lower indices in the array will be ignored.
  11787. + * Loads latest data from the {@link android.media.AudioRecord} in a non-blocking way. Only
  11788. + * supporting ENCODING_PCM_16BIT and ENCODING_PCM_FLOAT.
  11789. + *
  11790. + * @param record an instance of {@link android.media.AudioRecord}
  11791. + * @return number of captured audio values whose size is {@code channelCount * sampleCount}. If
  11792. + * there was no new data in the AudioRecord or an error occurred, this method will return 0.
  11793. + * @throws IllegalArgumentException for unsupported audio encoding format
  11794. + * @throws IllegalStateException if reading from AudioRecord failed
  11795. */
  11796. - public void load(float[] newData) {
  11797. - load(newData, 0, newData.length);
  11798. + @RequiresApi(Build.VERSION_CODES.M)
  11799. + public int load(AudioRecord record) {
  11800. + checkArgument(this.format.equals(TensorAudioFormat.create(record.getFormat())),
  11801. + "Incompatible audio format.");
  11802. + int loadedValues = 0;
  11803. + if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_FLOAT) {
  11804. + float[] newData = new float[record.getChannelCount() * record.getBufferSizeInFrames()];
  11805. + loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
  11806. + if (loadedValues > 0) {
  11807. + load(newData, 0, loadedValues);
  11808. + return loadedValues;
  11809. + }
  11810. + } else if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_16BIT) {
  11811. + short[] newData = new short[record.getChannelCount() * record.getBufferSizeInFrames()];
  11812. + loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
  11813. + if (loadedValues > 0) {
  11814. + load(newData, 0, loadedValues);
  11815. + return loadedValues;
  11816. + }
  11817. + } else {
  11818. + throw new IllegalArgumentException(
  11819. + "Unsupported encoding. Requires ENCODING_PCM_16BIT or ENCODING_PCM_FLOAT.");
  11820. + }
  11821. +
  11822. + switch (loadedValues) {
  11823. + case AudioRecord.ERROR_INVALID_OPERATION:
  11824. + throw new IllegalStateException("AudioRecord.ERROR_INVALID_OPERATION");
  11825. +
  11826. + case AudioRecord.ERROR_BAD_VALUE:
  11827. + throw new IllegalStateException("AudioRecord.ERROR_BAD_VALUE");
  11828. +
  11829. + case AudioRecord.ERROR_DEAD_OBJECT:
  11830. + throw new IllegalStateException("AudioRecord.ERROR_DEAD_OBJECT");
  11831. +
  11832. + case AudioRecord.ERROR:
  11833. + throw new IllegalStateException("AudioRecord.ERROR");
  11834. +
  11835. + default:
  11836. + return 0;
  11837. + }
  11838. }
  11839. /**
  11840. - * Loads a slice of the float array to the ring buffer. If the float array is longer than ring
  11841. - * buffer's capacity, samples with lower indices in the array will be ignored.
  11842. + * Returns a float {@link TensorBuffer} holding all the available audio samples in {@link
  11843. + * android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1].
  11844. */
  11845. - public void load(float[] newData, int offset, int size) {
  11846. - checkArgument(
  11847. - offset + size <= newData.length,
  11848. - String.format(
  11849. - "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
  11850. - offset, size, newData.length));
  11851. - // If buffer can't hold all the data, only keep the most recent data of size buffer.length
  11852. - if (size > buffer.length) {
  11853. - offset += (size - buffer.length);
  11854. - size = buffer.length;
  11855. - }
  11856. - if (nextIndex + size < buffer.length) {
  11857. - // No need to wrap nextIndex, just copy newData[offset:offset + size]
  11858. - // to buffer[nextIndex:nextIndex+size]
  11859. - arraycopy(newData, offset, buffer, nextIndex, size);
  11860. - } else {
  11861. - // Need to wrap nextIndex, perform copy in two chunks.
  11862. - int firstChunkSize = buffer.length - nextIndex;
  11863. - // First copy newData[offset:offset+firstChunkSize] to buffer[nextIndex:buffer.length]
  11864. - arraycopy(newData, offset, buffer, nextIndex, firstChunkSize);
  11865. - // Then copy newData[offset+firstChunkSize:offset+size] to buffer[0:size-firstChunkSize]
  11866. - arraycopy(newData, offset + firstChunkSize, buffer, 0, size - firstChunkSize);
  11867. - }
  11868. -
  11869. - nextIndex = (nextIndex + size) % buffer.length;
  11870. + public TensorBuffer getTensorBuffer() {
  11871. + ByteBuffer byteBuffer = buffer.getBuffer();
  11872. + TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(
  11873. + new int[] {/* batch= */ 1,
  11874. + /* modelInputLengthInFloat= */ byteBuffer.asFloatBuffer().limit()},
  11875. + DataType.FLOAT32);
  11876. + tensorBuffer.loadBuffer(byteBuffer);
  11877. + return tensorBuffer;
  11878. + }
  11879. +
  11880. + /* Returns the {@link TensorAudioFormat} associated with the tensor. */
  11881. + public TensorAudioFormat getFormat() {
  11882. + return format;
  11883. }
  11884. - public ByteBuffer getBuffer() {
  11885. - // Create non-direct buffers. On Pixel 4, creating direct buffer costs around 0.1 ms, which
  11886. - // can be 5x ~ 10x longer compared to non-direct buffer backed by arrays (around 0.01ms), so
  11887. - // generally we don't create direct buffer for every invocation.
  11888. - ByteBuffer byteBuffer = ByteBuffer.allocate(DataType.FLOAT32.byteSize() * buffer.length);
  11889. - byteBuffer.order(ByteOrder.nativeOrder());
  11890. - FloatBuffer result = byteBuffer.asFloatBuffer();
  11891. - result.put(buffer, nextIndex, buffer.length - nextIndex);
  11892. - result.put(buffer, 0, nextIndex);
  11893. - byteBuffer.rewind();
  11894. - return byteBuffer;
  11895. + private TensorAudio(TensorAudioFormat format, int sampleCounts) {
  11896. + this.format = format;
  11897. + this.buffer = new FloatRingBuffer(sampleCounts * format.getChannels());
  11898. }
  11899. - public int getCapacity() {
  11900. - return buffer.length;
  11901. + /** Actual implementation of the ring buffer. */
  11902. + private static class FloatRingBuffer {
  11903. + private final float[] buffer;
  11904. + private int nextIndex = 0;
  11905. +
  11906. + public FloatRingBuffer(int flatSize) {
  11907. + buffer = new float[flatSize];
  11908. + }
  11909. +
  11910. + /**
  11911. + * Loads the entire float array to the ring buffer. If the float array is longer than ring
  11912. + * buffer's capacity, samples with lower indices in the array will be ignored.
  11913. + */
  11914. + public void load(float[] newData) {
  11915. + load(newData, 0, newData.length);
  11916. + }
  11917. +
  11918. + /**
  11919. + * Loads a slice of the float array to the ring buffer. If the float array is longer than
  11920. + * ring buffer's capacity, samples with lower indices in the array will be ignored.
  11921. + */
  11922. + public void load(float[] newData, int offset, int size) {
  11923. + checkArgument(offset + size <= newData.length,
  11924. + String.format(
  11925. + "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
  11926. + offset, size, newData.length));
  11927. + // If buffer can't hold all the data, only keep the most recent data of size
  11928. + // buffer.length
  11929. + if (size > buffer.length) {
  11930. + offset += (size - buffer.length);
  11931. + size = buffer.length;
  11932. + }
  11933. + if (nextIndex + size < buffer.length) {
  11934. + // No need to wrap nextIndex, just copy newData[offset:offset + size]
  11935. + // to buffer[nextIndex:nextIndex+size]
  11936. + arraycopy(newData, offset, buffer, nextIndex, size);
  11937. + } else {
  11938. + // Need to wrap nextIndex, perform copy in two chunks.
  11939. + int firstChunkSize = buffer.length - nextIndex;
  11940. + // First copy newData[offset:offset+firstChunkSize] to
  11941. + // buffer[nextIndex:buffer.length]
  11942. + arraycopy(newData, offset, buffer, nextIndex, firstChunkSize);
  11943. + // Then copy newData[offset+firstChunkSize:offset+size] to
  11944. + // buffer[0:size-firstChunkSize]
  11945. + arraycopy(newData, offset + firstChunkSize, buffer, 0, size - firstChunkSize);
  11946. + }
  11947. +
  11948. + nextIndex = (nextIndex + size) % buffer.length;
  11949. + }
  11950. +
  11951. + public ByteBuffer getBuffer() {
  11952. + // Create non-direct buffers. On Pixel 4, creating direct buffer costs around 0.1 ms,
  11953. + // which can be 5x ~ 10x longer compared to non-direct buffer backed by arrays (around
  11954. + // 0.01ms), so generally we don't create direct buffer for every invocation.
  11955. + ByteBuffer byteBuffer =
  11956. + ByteBuffer.allocate(DataType.FLOAT32.byteSize() * buffer.length);
  11957. + byteBuffer.order(ByteOrder.nativeOrder());
  11958. + FloatBuffer result = byteBuffer.asFloatBuffer();
  11959. + result.put(buffer, nextIndex, buffer.length - nextIndex);
  11960. + result.put(buffer, 0, nextIndex);
  11961. + byteBuffer.rewind();
  11962. + return byteBuffer;
  11963. + }
  11964. +
  11965. + public int getCapacity() {
  11966. + return buffer.length;
  11967. + }
  11968. }
  11969. - }
  11970. }
  11971. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java
  11972. index 776391b526b47..6090f85d99083 100644
  11973. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java
  11974. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java
  11975. @@ -17,6 +17,10 @@ package org.tensorflow.lite.support.common;
  11976. import android.content.Context;
  11977. import android.content.res.AssetFileDescriptor;
  11978. +
  11979. +import org.checkerframework.checker.nullness.qual.NonNull;
  11980. +import org.tensorflow.lite.support.common.internal.SupportPreconditions;
  11981. +
  11982. import java.io.BufferedReader;
  11983. import java.io.FileInputStream;
  11984. import java.io.IOException;
  11985. @@ -28,160 +32,159 @@ import java.nio.channels.FileChannel;
  11986. import java.nio.charset.Charset;
  11987. import java.util.ArrayList;
  11988. import java.util.List;
  11989. -import org.checkerframework.checker.nullness.qual.NonNull;
  11990. -import org.tensorflow.lite.support.common.internal.SupportPreconditions;
  11991. /** File I/O utilities. */
  11992. public class FileUtil {
  11993. - private FileUtil() {}
  11994. -
  11995. - /**
  11996. - * Loads labels from the label file into a list of strings.
  11997. - *
  11998. - * <p>A legal label file is the plain text file whose contents are split into lines, and each line
  11999. - * is an individual value. The file should be in assets of the context.
  12000. - *
  12001. - * @param context The context holds assets.
  12002. - * @param filePath The path of the label file, relative with assets directory.
  12003. - * @return a list of labels.
  12004. - * @throws IOException if error occurs to open or read the file.
  12005. - */
  12006. - @NonNull
  12007. - public static List<String> loadLabels(@NonNull Context context, @NonNull String filePath)
  12008. - throws IOException {
  12009. - return loadLabels(context, filePath, Charset.defaultCharset());
  12010. - }
  12011. -
  12012. - /**
  12013. - * Loads labels from the label file into a list of strings.
  12014. - *
  12015. - * <p>A legal label file is the plain text file whose contents are split into lines, and each line
  12016. - * is an individual value. The empty lines will be ignored. The file should be in assets of the
  12017. - * context.
  12018. - *
  12019. - * @param context The context holds assets.
  12020. - * @param filePath The path of the label file, relative with assets directory.
  12021. - * @param cs {@code Charset} to use when decoding content of label file.
  12022. - * @return a list of labels.
  12023. - * @throws IOException if error occurs to open or read the file.
  12024. - */
  12025. - @NonNull
  12026. - public static List<String> loadLabels(
  12027. - @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
  12028. - SupportPreconditions.checkNotNull(context, "Context cannot be null.");
  12029. - SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
  12030. - try (InputStream inputStream = context.getAssets().open(filePath)) {
  12031. - return loadLabels(inputStream, cs);
  12032. + private FileUtil() {}
  12033. +
  12034. + /**
  12035. + * Loads labels from the label file into a list of strings.
  12036. + *
  12037. + * <p>A legal label file is the plain text file whose contents are split into lines, and each
  12038. + * line is an individual value. The file should be in assets of the context.
  12039. + *
  12040. + * @param context The context holds assets.
  12041. + * @param filePath The path of the label file, relative with assets directory.
  12042. + * @return a list of labels.
  12043. + * @throws IOException if error occurs to open or read the file.
  12044. + */
  12045. + @NonNull
  12046. + public static List<String> loadLabels(@NonNull Context context, @NonNull String filePath)
  12047. + throws IOException {
  12048. + return loadLabels(context, filePath, Charset.defaultCharset());
  12049. }
  12050. - }
  12051. -
  12052. - /**
  12053. - * Loads labels from an input stream of an opened label file. See details for label files in
  12054. - * {@link FileUtil#loadLabels(Context, String)}.
  12055. - *
  12056. - * @param inputStream the input stream of an opened label file.
  12057. - * @return a list of labels.
  12058. - * @throws IOException if error occurs to open or read the file.
  12059. - */
  12060. - @NonNull
  12061. - public static List<String> loadLabels(@NonNull InputStream inputStream) throws IOException {
  12062. - return loadLabels(inputStream, Charset.defaultCharset());
  12063. - }
  12064. -
  12065. - /**
  12066. - * Loads labels from an input stream of an opened label file. See details for label files in
  12067. - * {@link FileUtil#loadLabels(Context, String)}.
  12068. - *
  12069. - * @param inputStream the input stream of an opened label file.
  12070. - * @param cs {@code Charset} to use when decoding content of label file.
  12071. - * @return a list of labels.
  12072. - * @throws IOException if error occurs to open or read the file.
  12073. - */
  12074. - @NonNull
  12075. - public static List<String> loadLabels(@NonNull InputStream inputStream, Charset cs)
  12076. - throws IOException {
  12077. - List<String> labels = new ArrayList<>();
  12078. - try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, cs))) {
  12079. - String line;
  12080. - while ((line = reader.readLine()) != null) {
  12081. - if (line.trim().length() > 0) {
  12082. - labels.add(line);
  12083. +
  12084. + /**
  12085. + * Loads labels from the label file into a list of strings.
  12086. + *
  12087. + * <p>A legal label file is the plain text file whose contents are split into lines, and each
  12088. + * line is an individual value. The empty lines will be ignored. The file should be in assets of
  12089. + * the context.
  12090. + *
  12091. + * @param context The context holds assets.
  12092. + * @param filePath The path of the label file, relative with assets directory.
  12093. + * @param cs {@code Charset} to use when decoding content of label file.
  12094. + * @return a list of labels.
  12095. + * @throws IOException if error occurs to open or read the file.
  12096. + */
  12097. + @NonNull
  12098. + public static List<String> loadLabels(
  12099. + @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
  12100. + SupportPreconditions.checkNotNull(context, "Context cannot be null.");
  12101. + SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
  12102. + try (InputStream inputStream = context.getAssets().open(filePath)) {
  12103. + return loadLabels(inputStream, cs);
  12104. }
  12105. - }
  12106. - return labels;
  12107. }
  12108. - }
  12109. -
  12110. - /**
  12111. - * Loads a vocabulary file (a single-column text file) into a list of strings.
  12112. - *
  12113. - * <p>A vocabulary file is a single-column plain text file whose contents are split into lines,
  12114. - * and each line is an individual value. The file should be in assets of the context.
  12115. - *
  12116. - * @param context The context holds assets.
  12117. - * @param filePath The path of the vocabulary file, relative with assets directory.
  12118. - * @return a list of vocabulary words.
  12119. - * @throws IOException if error occurs to open or read the file.
  12120. - */
  12121. - @NonNull
  12122. - public static List<String> loadSingleColumnTextFile(
  12123. - @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
  12124. - return loadLabels(context, filePath, cs);
  12125. - }
  12126. -
  12127. - /**
  12128. - * Loads vocabulary from an input stream of an opened vocabulary file (which is a single-column
  12129. - * text file).
  12130. - *
  12131. - * <p>A vocabulary file is a single-column plain text file whose contents are split into lines,
  12132. - * and each line is an individual value. The file should be in assets of the context.
  12133. - *
  12134. - * @param inputStream the input stream of an opened vocabulary file.
  12135. - * @return a list of vocabulary words.
  12136. - * @throws IOException if error occurs to open or read the file.
  12137. - */
  12138. - @NonNull
  12139. - public static List<String> loadSingleColumnTextFile(@NonNull InputStream inputStream, Charset cs)
  12140. - throws IOException {
  12141. - return loadLabels(inputStream, cs);
  12142. - }
  12143. -
  12144. - /**
  12145. - * Loads a file from the asset folder through memory mapping.
  12146. - *
  12147. - * @param context Application context to access assets.
  12148. - * @param filePath Asset path of the file.
  12149. - * @return the loaded memory mapped file.
  12150. - * @throws IOException if an I/O error occurs when loading the tflite model.
  12151. - */
  12152. - @NonNull
  12153. - public static MappedByteBuffer loadMappedFile(@NonNull Context context, @NonNull String filePath)
  12154. - throws IOException {
  12155. - SupportPreconditions.checkNotNull(context, "Context should not be null.");
  12156. - SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
  12157. - try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
  12158. - FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
  12159. - FileChannel fileChannel = inputStream.getChannel();
  12160. - long startOffset = fileDescriptor.getStartOffset();
  12161. - long declaredLength = fileDescriptor.getDeclaredLength();
  12162. - return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  12163. +
  12164. + /**
  12165. + * Loads labels from an input stream of an opened label file. See details for label files in
  12166. + * {@link FileUtil#loadLabels(Context, String)}.
  12167. + *
  12168. + * @param inputStream the input stream of an opened label file.
  12169. + * @return a list of labels.
  12170. + * @throws IOException if error occurs to open or read the file.
  12171. + */
  12172. + @NonNull
  12173. + public static List<String> loadLabels(@NonNull InputStream inputStream) throws IOException {
  12174. + return loadLabels(inputStream, Charset.defaultCharset());
  12175. + }
  12176. +
  12177. + /**
  12178. + * Loads labels from an input stream of an opened label file. See details for label files in
  12179. + * {@link FileUtil#loadLabels(Context, String)}.
  12180. + *
  12181. + * @param inputStream the input stream of an opened label file.
  12182. + * @param cs {@code Charset} to use when decoding content of label file.
  12183. + * @return a list of labels.
  12184. + * @throws IOException if error occurs to open or read the file.
  12185. + */
  12186. + @NonNull
  12187. + public static List<String> loadLabels(@NonNull InputStream inputStream, Charset cs)
  12188. + throws IOException {
  12189. + List<String> labels = new ArrayList<>();
  12190. + try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, cs))) {
  12191. + String line;
  12192. + while ((line = reader.readLine()) != null) {
  12193. + if (line.trim().length() > 0) {
  12194. + labels.add(line);
  12195. + }
  12196. + }
  12197. + return labels;
  12198. + }
  12199. + }
  12200. +
  12201. + /**
  12202. + * Loads a vocabulary file (a single-column text file) into a list of strings.
  12203. + *
  12204. + * <p>A vocabulary file is a single-column plain text file whose contents are split into lines,
  12205. + * and each line is an individual value. The file should be in assets of the context.
  12206. + *
  12207. + * @param context The context holds assets.
  12208. + * @param filePath The path of the vocabulary file, relative with assets directory.
  12209. + * @return a list of vocabulary words.
  12210. + * @throws IOException if error occurs to open or read the file.
  12211. + */
  12212. + @NonNull
  12213. + public static List<String> loadSingleColumnTextFile(
  12214. + @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
  12215. + return loadLabels(context, filePath, cs);
  12216. + }
  12217. +
  12218. + /**
  12219. + * Loads vocabulary from an input stream of an opened vocabulary file (which is a single-column
  12220. + * text file).
  12221. + *
  12222. + * <p>A vocabulary file is a single-column plain text file whose contents are split into lines,
  12223. + * and each line is an individual value. The file should be in assets of the context.
  12224. + *
  12225. + * @param inputStream the input stream of an opened vocabulary file.
  12226. + * @return a list of vocabulary words.
  12227. + * @throws IOException if error occurs to open or read the file.
  12228. + */
  12229. + @NonNull
  12230. + public static List<String> loadSingleColumnTextFile(
  12231. + @NonNull InputStream inputStream, Charset cs) throws IOException {
  12232. + return loadLabels(inputStream, cs);
  12233. + }
  12234. +
  12235. + /**
  12236. + * Loads a file from the asset folder through memory mapping.
  12237. + *
  12238. + * @param context Application context to access assets.
  12239. + * @param filePath Asset path of the file.
  12240. + * @return the loaded memory mapped file.
  12241. + * @throws IOException if an I/O error occurs when loading the tflite model.
  12242. + */
  12243. + @NonNull
  12244. + public static MappedByteBuffer loadMappedFile(
  12245. + @NonNull Context context, @NonNull String filePath) throws IOException {
  12246. + SupportPreconditions.checkNotNull(context, "Context should not be null.");
  12247. + SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
  12248. + try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
  12249. + FileInputStream inputStream =
  12250. + new FileInputStream(fileDescriptor.getFileDescriptor())) {
  12251. + FileChannel fileChannel = inputStream.getChannel();
  12252. + long startOffset = fileDescriptor.getStartOffset();
  12253. + long declaredLength = fileDescriptor.getDeclaredLength();
  12254. + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  12255. + }
  12256. + }
  12257. +
  12258. + /**
  12259. + * Loads a binary file from the asset folder.
  12260. + *
  12261. + * @param context Application context to access assets.
  12262. + * @param filePath Asset path of the file.
  12263. + * @return the byte array for the binary file.
  12264. + * @throws IOException if an I/O error occurs when loading file.
  12265. + */
  12266. + @NonNull
  12267. + public static byte[] loadByteFromFile(@NonNull Context context, @NonNull String filePath)
  12268. + throws IOException {
  12269. + ByteBuffer buffer = loadMappedFile(context, filePath);
  12270. + byte[] byteArray = new byte[buffer.remaining()];
  12271. + buffer.get(byteArray);
  12272. + return byteArray;
  12273. }
  12274. - }
  12275. -
  12276. - /**
  12277. - * Loads a binary file from the asset folder.
  12278. - *
  12279. - * @param context Application context to access assets.
  12280. - * @param filePath Asset path of the file.
  12281. - * @return the byte array for the binary file.
  12282. - * @throws IOException if an I/O error occurs when loading file.
  12283. - */
  12284. - @NonNull
  12285. - public static byte[] loadByteFromFile(@NonNull Context context, @NonNull String filePath)
  12286. - throws IOException {
  12287. - ByteBuffer buffer = loadMappedFile(context, filePath);
  12288. - byte[] byteArray = new byte[buffer.remaining()];
  12289. - buffer.get(byteArray);
  12290. - return byteArray;
  12291. - }
  12292. }
  12293. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java
  12294. index 38dfe8818cbbc..45dfc4d9d868b 100644
  12295. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java
  12296. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java
  12297. @@ -20,12 +20,11 @@ package org.tensorflow.lite.support.common;
  12298. * @param <T> The class which Operator handles.
  12299. */
  12300. public interface Operator<T> {
  12301. -
  12302. - /**
  12303. - * Applies an operation on a T object, returning a T object.
  12304. - *
  12305. - * <p>Note: The returned object could probably be the same one with given input, and given input
  12306. - * could probably be changed.
  12307. - */
  12308. - T apply(T x);
  12309. + /**
  12310. + * Applies an operation on a T object, returning a T object.
  12311. + *
  12312. + * <p>Note: The returned object could probably be the same one with given input, and given input
  12313. + * could probably be changed.
  12314. + */
  12315. + T apply(T x);
  12316. }
  12317. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java
  12318. index 9d0024b2f5887..a94adb89b8666 100644
  12319. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java
  12320. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java
  12321. @@ -17,5 +17,5 @@ package org.tensorflow.lite.support.common;
  12322. /** Processes T object with prepared {@code Operator<T>}. */
  12323. public interface Processor<T> {
  12324. - T process(T input);
  12325. + T process(T input);
  12326. }
  12327. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java
  12328. index af688c863c254..aa900b7c93d87 100644
  12329. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java
  12330. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java
  12331. @@ -15,13 +15,14 @@ limitations under the License.
  12332. package org.tensorflow.lite.support.common;
  12333. +import org.checkerframework.checker.nullness.qual.NonNull;
  12334. +import org.tensorflow.lite.support.common.internal.SupportPreconditions;
  12335. +
  12336. import java.util.ArrayList;
  12337. import java.util.Collections;
  12338. import java.util.HashMap;
  12339. import java.util.List;
  12340. import java.util.Map;
  12341. -import org.checkerframework.checker.nullness.qual.NonNull;
  12342. -import org.tensorflow.lite.support.common.internal.SupportPreconditions;
  12343. /**
  12344. * A processor base class that chains a serial of {@code Operator<T>} and executes them.
  12345. @@ -32,52 +33,50 @@ import org.tensorflow.lite.support.common.internal.SupportPreconditions;
  12346. * @param <T> The type that the Operator is handling.
  12347. */
  12348. public class SequentialProcessor<T> implements Processor<T> {
  12349. + /** List of operators added to this {@link SequentialProcessor}. */
  12350. + protected final List<Operator<T>> operatorList;
  12351. + /**
  12352. + * The {@link Map} between the operator name and the corresponding op indexes in {@code
  12353. + * operatorList}. An operator may be added multiple times into this {@link SequentialProcessor}.
  12354. + */
  12355. + protected final Map<String, List<Integer>> operatorIndex;
  12356. - /** List of operators added to this {@link SequentialProcessor}. */
  12357. - protected final List<Operator<T>> operatorList;
  12358. - /**
  12359. - * The {@link Map} between the operator name and the corresponding op indexes in {@code
  12360. - * operatorList}. An operator may be added multiple times into this {@link SequentialProcessor}.
  12361. - */
  12362. - protected final Map<String, List<Integer>> operatorIndex;
  12363. -
  12364. - protected SequentialProcessor(Builder<T> builder) {
  12365. - operatorList = builder.operatorList;
  12366. - operatorIndex = Collections.unmodifiableMap(builder.operatorIndex);
  12367. - }
  12368. + protected SequentialProcessor(Builder<T> builder) {
  12369. + operatorList = builder.operatorList;
  12370. + operatorIndex = Collections.unmodifiableMap(builder.operatorIndex);
  12371. + }
  12372. - @Override
  12373. - public T process(T x) {
  12374. - for (Operator<T> op : operatorList) {
  12375. - x = op.apply(x);
  12376. + @Override
  12377. + public T process(T x) {
  12378. + for (Operator<T> op : operatorList) {
  12379. + x = op.apply(x);
  12380. + }
  12381. + return x;
  12382. }
  12383. - return x;
  12384. - }
  12385. - /** The inner builder class to build a Sequential Processor. */
  12386. - protected static class Builder<T> {
  12387. + /** The inner builder class to build a Sequential Processor. */
  12388. + protected static class Builder<T> {
  12389. + private final List<Operator<T>> operatorList;
  12390. + private final Map<String, List<Integer>> operatorIndex;
  12391. - private final List<Operator<T>> operatorList;
  12392. - private final Map<String, List<Integer>> operatorIndex;
  12393. + protected Builder() {
  12394. + operatorList = new ArrayList<>();
  12395. + operatorIndex = new HashMap<>();
  12396. + }
  12397. - protected Builder() {
  12398. - operatorList = new ArrayList<>();
  12399. - operatorIndex = new HashMap<>();
  12400. - }
  12401. -
  12402. - public Builder<T> add(@NonNull Operator<T> op) {
  12403. - SupportPreconditions.checkNotNull(op, "Adding null Op is illegal.");
  12404. - operatorList.add(op);
  12405. - String operatorName = op.getClass().getName();
  12406. - if (!operatorIndex.containsKey(operatorName)) {
  12407. - operatorIndex.put(operatorName, new ArrayList<Integer>());
  12408. - }
  12409. - operatorIndex.get(operatorName).add(operatorList.size() - 1);
  12410. - return this;
  12411. - }
  12412. + public Builder<T> add(@NonNull Operator<T> op) {
  12413. + SupportPreconditions.checkNotNull(op, "Adding null Op is illegal.");
  12414. + operatorList.add(op);
  12415. + String operatorName = op.getClass().getName();
  12416. + if (!operatorIndex.containsKey(operatorName)) {
  12417. + operatorIndex.put(operatorName, new ArrayList<Integer>());
  12418. + }
  12419. + operatorIndex.get(operatorName).add(operatorList.size() - 1);
  12420. + return this;
  12421. + }
  12422. - public SequentialProcessor<T> build() {
  12423. - return new SequentialProcessor<T>(this);
  12424. + public SequentialProcessor<T> build() {
  12425. + return new SequentialProcessor<T>(this);
  12426. + }
  12427. }
  12428. - }
  12429. }
  12430. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java
  12431. index d1b7021df257c..692c2d479dcce 100644
  12432. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java
  12433. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java
  12434. @@ -21,7 +21,7 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  12435. * Applies some operation on TensorBuffers.
  12436. */
  12437. public interface TensorOperator extends Operator<TensorBuffer> {
  12438. - /** @see Operator#apply(Object) . */
  12439. - @Override
  12440. - TensorBuffer apply(TensorBuffer input);
  12441. + /** @see Operator#apply(Object) . */
  12442. + @Override
  12443. + TensorBuffer apply(TensorBuffer input);
  12444. }
  12445. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java
  12446. index 8096d0c764bab..faad66edeb04e 100644
  12447. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java
  12448. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java
  12449. @@ -32,37 +32,36 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  12450. * @see TensorProcessor#process to apply the processor on a {@code TensorBuffer}.
  12451. */
  12452. public class TensorProcessor extends SequentialProcessor<TensorBuffer> {
  12453. - private TensorProcessor(Builder builder) {
  12454. - super(builder);
  12455. - }
  12456. -
  12457. - /** The Builder to create an {@link TensorProcessor}, which could be executed later. */
  12458. - public static class Builder extends SequentialProcessor.Builder<TensorBuffer> {
  12459. -
  12460. - /**
  12461. - * Creates a Builder to build {@link TensorProcessor}.
  12462. - *
  12463. - * @see #add(TensorOperator) to add an Op.
  12464. - * @see #build() to complete the building process and get a built Processor.
  12465. - */
  12466. - public Builder() {
  12467. - super();
  12468. + private TensorProcessor(Builder builder) {
  12469. + super(builder);
  12470. }
  12471. - /**
  12472. - * Adds an {@link TensorOperator} into the Operator chain.
  12473. - *
  12474. - * @param op the Operator instance to be executed then.
  12475. - */
  12476. - public TensorProcessor.Builder add(TensorOperator op) {
  12477. - super.add(op);
  12478. - return this;
  12479. - }
  12480. + /** The Builder to create an {@link TensorProcessor}, which could be executed later. */
  12481. + public static class Builder extends SequentialProcessor.Builder<TensorBuffer> {
  12482. + /**
  12483. + * Creates a Builder to build {@link TensorProcessor}.
  12484. + *
  12485. + * @see #add(TensorOperator) to add an Op.
  12486. + * @see #build() to complete the building process and get a built Processor.
  12487. + */
  12488. + public Builder() {
  12489. + super();
  12490. + }
  12491. +
  12492. + /**
  12493. + * Adds an {@link TensorOperator} into the Operator chain.
  12494. + *
  12495. + * @param op the Operator instance to be executed then.
  12496. + */
  12497. + public TensorProcessor.Builder add(TensorOperator op) {
  12498. + super.add(op);
  12499. + return this;
  12500. + }
  12501. - /** Completes the building process and gets the {@link TensorProcessor} instance. */
  12502. - @Override
  12503. - public TensorProcessor build() {
  12504. - return new TensorProcessor(this);
  12505. + /** Completes the building process and gets the {@link TensorProcessor} instance. */
  12506. + @Override
  12507. + public TensorProcessor build() {
  12508. + return new TensorProcessor(this);
  12509. + }
  12510. }
  12511. - }
  12512. }
  12513. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java
  12514. index e3e962a5f8252..29faa545b71f2 100644
  12515. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java
  12516. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java
  12517. @@ -19,164 +19,168 @@ import org.checkerframework.checker.nullness.qual.Nullable;
  12518. /** Static error checking util methods. */
  12519. public final class SupportPreconditions {
  12520. - /**
  12521. - * Ensures that an object reference passed as a parameter to the calling method is not null.
  12522. - *
  12523. - * @param reference an object reference
  12524. - * @return the non-null reference that was validated
  12525. - * @throws NullPointerException if {@code reference} is null
  12526. - */
  12527. - public static <T extends Object> T checkNotNull(T reference) {
  12528. - if (reference == null) {
  12529. - throw new NullPointerException("The object reference is null.");
  12530. + /**
  12531. + * Ensures that an object reference passed as a parameter to the calling method is not null.
  12532. + *
  12533. + * @param reference an object reference
  12534. + * @return the non-null reference that was validated
  12535. + * @throws NullPointerException if {@code reference} is null
  12536. + */
  12537. + public static <T extends Object> T checkNotNull(T reference) {
  12538. + if (reference == null) {
  12539. + throw new NullPointerException("The object reference is null.");
  12540. + }
  12541. + return reference;
  12542. }
  12543. - return reference;
  12544. - }
  12545. -
  12546. - /**
  12547. - * Ensures that an object reference passed as a parameter to the calling method is not null.
  12548. - *
  12549. - * @param reference an object reference
  12550. - * @param errorMessage the exception message to use if the check fails; will be converted to a
  12551. - * string using {@link String#valueOf(Object)}
  12552. - * @return the non-null reference that was validated
  12553. - * @throws NullPointerException if {@code reference} is null
  12554. - */
  12555. - public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
  12556. - if (reference == null) {
  12557. - throw new NullPointerException(String.valueOf(errorMessage));
  12558. +
  12559. + /**
  12560. + * Ensures that an object reference passed as a parameter to the calling method is not null.
  12561. + *
  12562. + * @param reference an object reference
  12563. + * @param errorMessage the exception message to use if the check fails; will be converted to a
  12564. + * string using {@link String#valueOf(Object)}
  12565. + * @return the non-null reference that was validated
  12566. + * @throws NullPointerException if {@code reference} is null
  12567. + */
  12568. + public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
  12569. + if (reference == null) {
  12570. + throw new NullPointerException(String.valueOf(errorMessage));
  12571. + }
  12572. + return reference;
  12573. + }
  12574. +
  12575. + /**
  12576. + * Ensures that the given String is not empty and not null.
  12577. + *
  12578. + * @param string the String to test
  12579. + * @return the non-null non-empty String that was validated
  12580. + * @throws IllegalArgumentException if {@code string} is null or empty
  12581. + */
  12582. + public static String checkNotEmpty(String string) {
  12583. + if (string == null || string.length() == 0) {
  12584. + throw new IllegalArgumentException("Given String is empty or null.");
  12585. + }
  12586. + return string;
  12587. }
  12588. - return reference;
  12589. - }
  12590. -
  12591. - /**
  12592. - * Ensures that the given String is not empty and not null.
  12593. - *
  12594. - * @param string the String to test
  12595. - * @return the non-null non-empty String that was validated
  12596. - * @throws IllegalArgumentException if {@code string} is null or empty
  12597. - */
  12598. - public static String checkNotEmpty(String string) {
  12599. - if (string == null || string.length() == 0) {
  12600. - throw new IllegalArgumentException("Given String is empty or null.");
  12601. +
  12602. + /**
  12603. + * Ensures that the given String is not empty and not null.
  12604. + *
  12605. + * @param string the String to test
  12606. + * @param errorMessage the exception message to use if the check fails; will be converted to a
  12607. + * string using {@link String#valueOf(Object)}
  12608. + * @return the non-null non-empty String that was validated
  12609. + * @throws IllegalArgumentException if {@code string} is null or empty
  12610. + */
  12611. + public static String checkNotEmpty(String string, Object errorMessage) {
  12612. + if (string == null || string.length() == 0) {
  12613. + throw new IllegalArgumentException(String.valueOf(errorMessage));
  12614. + }
  12615. + return string;
  12616. }
  12617. - return string;
  12618. - }
  12619. -
  12620. - /**
  12621. - * Ensures that the given String is not empty and not null.
  12622. - *
  12623. - * @param string the String to test
  12624. - * @param errorMessage the exception message to use if the check fails; will be converted to a
  12625. - * string using {@link String#valueOf(Object)}
  12626. - * @return the non-null non-empty String that was validated
  12627. - * @throws IllegalArgumentException if {@code string} is null or empty
  12628. - */
  12629. - public static String checkNotEmpty(String string, Object errorMessage) {
  12630. - if (string == null || string.length() == 0) {
  12631. - throw new IllegalArgumentException(String.valueOf(errorMessage));
  12632. +
  12633. + /**
  12634. + * Ensures the truth of an expression involving one or more parameters to the calling method.
  12635. + *
  12636. + * @param expression a boolean expression.
  12637. + * @throws IllegalArgumentException if {@code expression} is false.
  12638. + */
  12639. + public static void checkArgument(boolean expression) {
  12640. + if (!expression) {
  12641. + throw new IllegalArgumentException();
  12642. + }
  12643. }
  12644. - return string;
  12645. - }
  12646. -
  12647. - /**
  12648. - * Ensures the truth of an expression involving one or more parameters to the calling method.
  12649. - *
  12650. - * @param expression a boolean expression.
  12651. - * @throws IllegalArgumentException if {@code expression} is false.
  12652. - */
  12653. - public static void checkArgument(boolean expression) {
  12654. - if (!expression) {
  12655. - throw new IllegalArgumentException();
  12656. +
  12657. + /**
  12658. + * Ensures the truth of an expression involving one or more parameters to the calling method.
  12659. + *
  12660. + * @param expression a boolean expression.
  12661. + * @param errorMessage the exception message to use if the check fails; will be converted to a
  12662. + * string using {@link String#valueOf(Object)}.
  12663. + * @throws IllegalArgumentException if {@code expression} is false.
  12664. + */
  12665. + public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
  12666. + if (!expression) {
  12667. + throw new IllegalArgumentException(String.valueOf(errorMessage));
  12668. + }
  12669. }
  12670. - }
  12671. -
  12672. - /**
  12673. - * Ensures the truth of an expression involving one or more parameters to the calling method.
  12674. - *
  12675. - * @param expression a boolean expression.
  12676. - * @param errorMessage the exception message to use if the check fails; will be converted to a
  12677. - * string using {@link String#valueOf(Object)}.
  12678. - * @throws IllegalArgumentException if {@code expression} is false.
  12679. - */
  12680. - public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
  12681. - if (!expression) {
  12682. - throw new IllegalArgumentException(String.valueOf(errorMessage));
  12683. +
  12684. + /**
  12685. + * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of
  12686. + * size
  12687. + * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
  12688. + *
  12689. + * @param index a user-supplied index identifying an element of an array, list or string
  12690. + * @param size the size of that array, list or string
  12691. + * @return the value of {@code index}
  12692. + * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code
  12693. + * size}
  12694. + * @throws IllegalArgumentException if {@code size} is negative
  12695. + */
  12696. + public static int checkElementIndex(int index, int size) {
  12697. + return checkElementIndex(index, size, "index");
  12698. }
  12699. - }
  12700. -
  12701. - /**
  12702. - * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
  12703. - * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
  12704. - *
  12705. - * @param index a user-supplied index identifying an element of an array, list or string
  12706. - * @param size the size of that array, list or string
  12707. - * @return the value of {@code index}
  12708. - * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
  12709. - * @throws IllegalArgumentException if {@code size} is negative
  12710. - */
  12711. - public static int checkElementIndex(int index, int size) {
  12712. - return checkElementIndex(index, size, "index");
  12713. - }
  12714. -
  12715. - /**
  12716. - * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
  12717. - * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
  12718. - *
  12719. - * @param index a user-supplied index identifying an element of an array, list or string
  12720. - * @param size the size of that array, list or string
  12721. - * @param desc the text to use to describe this index in an error message
  12722. - * @return the value of {@code index}
  12723. - * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
  12724. - * @throws IllegalArgumentException if {@code size} is negative
  12725. - */
  12726. - public static int checkElementIndex(int index, int size, @Nullable String desc) {
  12727. - // Carefully optimized for execution by hotspot (explanatory comment above)
  12728. - if (index < 0 || index >= size) {
  12729. - throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
  12730. +
  12731. + /**
  12732. + * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of
  12733. + * size
  12734. + * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
  12735. + *
  12736. + * @param index a user-supplied index identifying an element of an array, list or string
  12737. + * @param size the size of that array, list or string
  12738. + * @param desc the text to use to describe this index in an error message
  12739. + * @return the value of {@code index}
  12740. + * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code
  12741. + * size}
  12742. + * @throws IllegalArgumentException if {@code size} is negative
  12743. + */
  12744. + public static int checkElementIndex(int index, int size, @Nullable String desc) {
  12745. + // Carefully optimized for execution by hotspot (explanatory comment above)
  12746. + if (index < 0 || index >= size) {
  12747. + throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
  12748. + }
  12749. + return index;
  12750. }
  12751. - return index;
  12752. - }
  12753. -
  12754. - /**
  12755. - * Ensures the truth of an expression involving the state of the calling instance, but not
  12756. - * involving any parameters to the calling method.
  12757. - *
  12758. - * @param expression a boolean expression
  12759. - * @throws IllegalStateException if {@code expression} is false
  12760. - */
  12761. - public static void checkState(boolean expression) {
  12762. - if (!expression) {
  12763. - throw new IllegalStateException();
  12764. +
  12765. + /**
  12766. + * Ensures the truth of an expression involving the state of the calling instance, but not
  12767. + * involving any parameters to the calling method.
  12768. + *
  12769. + * @param expression a boolean expression
  12770. + * @throws IllegalStateException if {@code expression} is false
  12771. + */
  12772. + public static void checkState(boolean expression) {
  12773. + if (!expression) {
  12774. + throw new IllegalStateException();
  12775. + }
  12776. }
  12777. - }
  12778. -
  12779. - /**
  12780. - * Ensures the truth of an expression involving the state of the calling instance, but not
  12781. - * involving any parameters to the calling method.
  12782. - *
  12783. - * @param expression a boolean expression
  12784. - * @param errorMessage the exception message to use if the check fails; will be converted to a
  12785. - * string using {@link String#valueOf(Object)}
  12786. - * @throws IllegalStateException if {@code expression} is false
  12787. - */
  12788. - public static void checkState(boolean expression, @Nullable Object errorMessage) {
  12789. - if (!expression) {
  12790. - throw new IllegalStateException(String.valueOf(errorMessage));
  12791. +
  12792. + /**
  12793. + * Ensures the truth of an expression involving the state of the calling instance, but not
  12794. + * involving any parameters to the calling method.
  12795. + *
  12796. + * @param expression a boolean expression
  12797. + * @param errorMessage the exception message to use if the check fails; will be converted to a
  12798. + * string using {@link String#valueOf(Object)}
  12799. + * @throws IllegalStateException if {@code expression} is false
  12800. + */
  12801. + public static void checkState(boolean expression, @Nullable Object errorMessage) {
  12802. + if (!expression) {
  12803. + throw new IllegalStateException(String.valueOf(errorMessage));
  12804. + }
  12805. }
  12806. - }
  12807. -
  12808. - private static String badElementIndex(int index, int size, @Nullable String desc) {
  12809. - if (index < 0) {
  12810. - return String.format("%s (%s) must not be negative", desc, index);
  12811. - } else if (size < 0) {
  12812. - throw new IllegalArgumentException("negative size: " + size);
  12813. - } else { // index >= size
  12814. - return String.format("%s (%s) must be less than size (%s)", desc, index, size);
  12815. +
  12816. + private static String badElementIndex(int index, int size, @Nullable String desc) {
  12817. + if (index < 0) {
  12818. + return String.format("%s (%s) must not be negative", desc, index);
  12819. + } else if (size < 0) {
  12820. + throw new IllegalArgumentException("negative size: " + size);
  12821. + } else { // index >= size
  12822. + return String.format("%s (%s) must be less than size (%s)", desc, index, size);
  12823. + }
  12824. }
  12825. - }
  12826. - private SupportPreconditions() {
  12827. - throw new AssertionError("SupportPreconditions is Uninstantiable.");
  12828. - }
  12829. + private SupportPreconditions() {
  12830. + throw new AssertionError("SupportPreconditions is Uninstantiable.");
  12831. + }
  12832. }
  12833. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java
  12834. index 742a1ef90994c..a14cd1f1e503d 100644
  12835. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java
  12836. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java
  12837. @@ -22,34 +22,33 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  12838. /** Casts a {@link TensorBuffer} to a specified data type. */
  12839. public class CastOp implements TensorOperator {
  12840. + private final DataType destinationType;
  12841. +
  12842. + /**
  12843. + * Constructs a CastOp.
  12844. + *
  12845. + * <p>Note: For only converting type for a certain {@link TensorBuffer} on-the-fly rather than
  12846. + * in a processor, please directly use {@link TensorBuffer#createFrom(TensorBuffer, DataType)}.
  12847. + *
  12848. + * <p>When this Op is executed, if the original {@link TensorBuffer} is already in {@code
  12849. + * destinationType}, the original buffer will be directly returned.
  12850. + *
  12851. + * @param destinationType The type of the casted {@link TensorBuffer}.
  12852. + * @throws IllegalArgumentException if {@code destinationType} is neither {@link DataType#UINT8}
  12853. + * nor {@link DataType#FLOAT32}.
  12854. + */
  12855. + public CastOp(DataType destinationType) {
  12856. + SupportPreconditions.checkArgument(
  12857. + destinationType == DataType.UINT8 || destinationType == DataType.FLOAT32,
  12858. + "Destination type " + destinationType + " is not supported.");
  12859. + this.destinationType = destinationType;
  12860. + }
  12861. - private final DataType destinationType;
  12862. -
  12863. - /**
  12864. - * Constructs a CastOp.
  12865. - *
  12866. - * <p>Note: For only converting type for a certain {@link TensorBuffer} on-the-fly rather than in
  12867. - * a processor, please directly use {@link TensorBuffer#createFrom(TensorBuffer, DataType)}.
  12868. - *
  12869. - * <p>When this Op is executed, if the original {@link TensorBuffer} is already in {@code
  12870. - * destinationType}, the original buffer will be directly returned.
  12871. - *
  12872. - * @param destinationType The type of the casted {@link TensorBuffer}.
  12873. - * @throws IllegalArgumentException if {@code destinationType} is neither {@link DataType#UINT8}
  12874. - * nor {@link DataType#FLOAT32}.
  12875. - */
  12876. - public CastOp(DataType destinationType) {
  12877. - SupportPreconditions.checkArgument(
  12878. - destinationType == DataType.UINT8 || destinationType == DataType.FLOAT32,
  12879. - "Destination type " + destinationType + " is not supported.");
  12880. - this.destinationType = destinationType;
  12881. - }
  12882. -
  12883. - @Override
  12884. - public TensorBuffer apply(TensorBuffer input) {
  12885. - if (input.getDataType() == destinationType) {
  12886. - return input;
  12887. + @Override
  12888. + public TensorBuffer apply(TensorBuffer input) {
  12889. + if (input.getDataType() == destinationType) {
  12890. + return input;
  12891. + }
  12892. + return TensorBuffer.createFrom(input, destinationType);
  12893. }
  12894. - return TensorBuffer.createFrom(input, destinationType);
  12895. - }
  12896. }
  12897. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java
  12898. index 1881747870be3..8b6d183189b7f 100644
  12899. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java
  12900. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java
  12901. @@ -32,9 +32,8 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  12902. * as 0.
  12903. */
  12904. public class DequantizeOp extends NormalizeOp implements TensorOperator {
  12905. -
  12906. - public DequantizeOp(float zeroPoint, float scale) {
  12907. - // Quantization: f = (q - z) * s
  12908. - super(zeroPoint, 1 / scale);
  12909. - }
  12910. + public DequantizeOp(float zeroPoint, float scale) {
  12911. + // Quantization: f = (q - z) * s
  12912. + super(zeroPoint, 1 / scale);
  12913. + }
  12914. }
  12915. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java
  12916. index cff4d0b55d60a..912df13b59cec 100644
  12917. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java
  12918. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java
  12919. @@ -26,135 +26,134 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat;
  12920. * Normalizes a {@link TensorBuffer} with given mean and stddev: output = (input - mean) / stddev.
  12921. */
  12922. public class NormalizeOp implements TensorOperator {
  12923. + // mean.length should always be equal to stddev.length and always >= 1.
  12924. + private final float[] mean;
  12925. + private final float[] stddev;
  12926. + private final int numChannels;
  12927. + private final boolean isIdentityOp;
  12928. - // mean.length should always be equal to stddev.length and always >= 1.
  12929. - private final float[] mean;
  12930. - private final float[] stddev;
  12931. - private final int numChannels;
  12932. - private final boolean isIdentityOp;
  12933. + /**
  12934. + * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which
  12935. + * satisfies:
  12936. + *
  12937. + * <pre>
  12938. + * output = (input - mean) / stddev
  12939. + * </pre>
  12940. + *
  12941. + * <p>In the following two cases, reset {@code mean} to 0 and {@code stddev} to 1 to bypass the
  12942. + * normalization. <br>
  12943. + * 1. Both {@code mean} and {code stddev} are 0. <br>
  12944. + * 2. {@code mean} is 0 and {stddev} is Infinity.
  12945. + *
  12946. + * <p>Note: If {@code mean} is set to 0 and {@code stddev} is set to 1, no computation will
  12947. + * happen, and original input will be directly returned in execution.
  12948. + *
  12949. + * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
  12950. + * present, except when the input is a {@link DataType#UINT8} tensor, {@code mean} is set to 0
  12951. + * and
  12952. + * {@code stddev} is set to 1, so that the original {@link DataType#UINT8} tensor is returned.
  12953. + *
  12954. + * @param mean the mean value to be subtracted first.
  12955. + * @param stddev the standard deviation value to divide then.
  12956. + * @throws IllegalArgumentException if {@code stddev} is zero.
  12957. + */
  12958. + public NormalizeOp(float mean, float stddev) {
  12959. + // Make exceptions to the cases that
  12960. + // 1. Both mean and stddev are 0.0f. This may happen when reading the normalization
  12961. + // parameters from a tensor which does not have the values populated in the metadata. The
  12962. + // same situation may also happen to the quantization parameters.
  12963. + // 2. mean is 0.0f and stddev is Infinity. This may happen when reading the quantization
  12964. + // parameters from a tensor which does not have the values populated in the metadata, and
  12965. + // then passing the parameters into the DequantizeOp. Bypass both of the two cases, by
  12966. + // reseting stddev to 1.0f.
  12967. + if (mean == 0.0f && (stddev == 0.0f || Float.isInfinite(stddev))) {
  12968. + stddev = 1.0f;
  12969. + }
  12970. - /**
  12971. - * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which
  12972. - * satisfies:
  12973. - *
  12974. - * <pre>
  12975. - * output = (input - mean) / stddev
  12976. - * </pre>
  12977. - *
  12978. - * <p>In the following two cases, reset {@code mean} to 0 and {@code stddev} to 1 to bypass the
  12979. - * normalization. <br>
  12980. - * 1. Both {@code mean} and {code stddev} are 0. <br>
  12981. - * 2. {@code mean} is 0 and {stddev} is Infinity.
  12982. - *
  12983. - * <p>Note: If {@code mean} is set to 0 and {@code stddev} is set to 1, no computation will
  12984. - * happen, and original input will be directly returned in execution.
  12985. - *
  12986. - * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
  12987. - * present, except when the input is a {@link DataType#UINT8} tensor, {@code mean} is set to 0 and
  12988. - * {@code stddev} is set to 1, so that the original {@link DataType#UINT8} tensor is returned.
  12989. - *
  12990. - * @param mean the mean value to be subtracted first.
  12991. - * @param stddev the standard deviation value to divide then.
  12992. - * @throws IllegalArgumentException if {@code stddev} is zero.
  12993. - */
  12994. - public NormalizeOp(float mean, float stddev) {
  12995. - // Make exceptions to the cases that
  12996. - // 1. Both mean and stddev are 0.0f. This may happen when reading the normalization parameters
  12997. - // from a tensor which does not have the values populated in the metadata. The same situation
  12998. - // may also happen to the quantization parameters.
  12999. - // 2. mean is 0.0f and stddev is Infinity. This may happen when reading the quantization
  13000. - // parameters from a tensor which does not have the values populated in the metadata, and then
  13001. - // passing the parameters into the DequantizeOp.
  13002. - // Bypass both of the two cases, by reseting stddev to 1.0f.
  13003. - if (mean == 0.0f && (stddev == 0.0f || Float.isInfinite(stddev))) {
  13004. - stddev = 1.0f;
  13005. - }
  13006. + SupportPreconditions.checkArgument(stddev != 0.0f, "Stddev cannot be zero.");
  13007. + boolean meansIsZeroAndDevsIs1 = false;
  13008. + if (mean == 0.0f && stddev == 1.0f) {
  13009. + meansIsZeroAndDevsIs1 = true;
  13010. + }
  13011. - SupportPreconditions.checkArgument(stddev != 0.0f, "Stddev cannot be zero.");
  13012. - boolean meansIsZeroAndDevsIs1 = false;
  13013. - if (mean == 0.0f && stddev == 1.0f) {
  13014. - meansIsZeroAndDevsIs1 = true;
  13015. + this.isIdentityOp = meansIsZeroAndDevsIs1;
  13016. + this.mean = new float[] {mean};
  13017. + this.stddev = new float[] {stddev};
  13018. + this.numChannels = 1;
  13019. }
  13020. - this.isIdentityOp = meansIsZeroAndDevsIs1;
  13021. - this.mean = new float[] {mean};
  13022. - this.stddev = new float[] {stddev};
  13023. - this.numChannels = 1;
  13024. - }
  13025. -
  13026. - /**
  13027. - * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which
  13028. - * satisfies:
  13029. - *
  13030. - * <pre>
  13031. - * // Pseudo code. [...][i] means a certain element whose channel id is i.
  13032. - * output[...][i] = (input[...][i] - mean[i]) / stddev[i]
  13033. - * </pre>
  13034. - *
  13035. - * <p>Note: If all values in {@code mean} are set to 0 and all {@code stddev} are set to 1, no
  13036. - * computation will happen, and original input will be directly returned in execution.
  13037. - *
  13038. - * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
  13039. - * present, except that the input is a {@link DataType#UINT8} tensor, all {@code mean} are set to
  13040. - * 0 and all {@code stddev} are set to 1.
  13041. - *
  13042. - * @param mean the mean values to be subtracted first for each channel.
  13043. - * @param stddev the standard deviation values to divide then for each channel.
  13044. - * @throws IllegalArgumentException if any {@code stddev} is zero, or {@code mean} has different
  13045. - * number of elements with {@code stddev}, or any of them is empty.
  13046. - */
  13047. - public NormalizeOp(@NonNull float[] mean, @NonNull float[] stddev) {
  13048. - SupportPreconditions.checkNotNull(mean, "Mean cannot be null");
  13049. - SupportPreconditions.checkNotNull(stddev, "Stddev cannot be null");
  13050. - SupportPreconditions.checkArgument(
  13051. - mean.length == stddev.length,
  13052. - "Per channel normalization requires same number of means and stddevs");
  13053. - SupportPreconditions.checkArgument(mean.length > 0, "Means and stddevs are empty.");
  13054. - this.mean = mean.clone();
  13055. - this.stddev = stddev.clone();
  13056. - boolean allMeansAreZeroAndAllDevsAre1 = true;
  13057. - this.numChannels = mean.length;
  13058. - for (int i = 0; i < numChannels; i++) {
  13059. - SupportPreconditions.checkArgument(this.stddev[i] != 0, "Stddev cannot be zero.");
  13060. - if (this.stddev[i] != 1 || this.mean[i] != 0) {
  13061. - allMeansAreZeroAndAllDevsAre1 = false;
  13062. - }
  13063. + /**
  13064. + * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which
  13065. + * satisfies:
  13066. + *
  13067. + * <pre>
  13068. + * // Pseudo code. [...][i] means a certain element whose channel id is i.
  13069. + * output[...][i] = (input[...][i] - mean[i]) / stddev[i]
  13070. + * </pre>
  13071. + *
  13072. + * <p>Note: If all values in {@code mean} are set to 0 and all {@code stddev} are set to 1, no
  13073. + * computation will happen, and original input will be directly returned in execution.
  13074. + *
  13075. + * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
  13076. + * present, except that the input is a {@link DataType#UINT8} tensor, all {@code mean} are set
  13077. + * to 0 and all {@code stddev} are set to 1.
  13078. + *
  13079. + * @param mean the mean values to be subtracted first for each channel.
  13080. + * @param stddev the standard deviation values to divide then for each channel.
  13081. + * @throws IllegalArgumentException if any {@code stddev} is zero, or {@code mean} has different
  13082. + * number of elements with {@code stddev}, or any of them is empty.
  13083. + */
  13084. + public NormalizeOp(@NonNull float[] mean, @NonNull float[] stddev) {
  13085. + SupportPreconditions.checkNotNull(mean, "Mean cannot be null");
  13086. + SupportPreconditions.checkNotNull(stddev, "Stddev cannot be null");
  13087. + SupportPreconditions.checkArgument(mean.length == stddev.length,
  13088. + "Per channel normalization requires same number of means and stddevs");
  13089. + SupportPreconditions.checkArgument(mean.length > 0, "Means and stddevs are empty.");
  13090. + this.mean = mean.clone();
  13091. + this.stddev = stddev.clone();
  13092. + boolean allMeansAreZeroAndAllDevsAre1 = true;
  13093. + this.numChannels = mean.length;
  13094. + for (int i = 0; i < numChannels; i++) {
  13095. + SupportPreconditions.checkArgument(this.stddev[i] != 0, "Stddev cannot be zero.");
  13096. + if (this.stddev[i] != 1 || this.mean[i] != 0) {
  13097. + allMeansAreZeroAndAllDevsAre1 = false;
  13098. + }
  13099. + }
  13100. + this.isIdentityOp = allMeansAreZeroAndAllDevsAre1;
  13101. }
  13102. - this.isIdentityOp = allMeansAreZeroAndAllDevsAre1;
  13103. - }
  13104. - /**
  13105. - * Applies the defined normalization on given tensor and returns the result.
  13106. - *
  13107. - * <p>Note: {@code input} is possibly the same instance with the output.
  13108. - *
  13109. - * @param input input tensor. It may be the same instance with the output.
  13110. - * @return output tensor.
  13111. - */
  13112. - @Override
  13113. - @NonNull
  13114. - public TensorBuffer apply(@NonNull TensorBuffer input) {
  13115. - if (isIdentityOp) {
  13116. - return input;
  13117. - }
  13118. - int[] shape = input.getShape();
  13119. - SupportPreconditions.checkArgument(
  13120. - numChannels == 1 || (shape.length != 0 && shape[shape.length - 1] == numChannels),
  13121. - "Number of means (stddevs) is not same with number of channels (size of last axis).");
  13122. - // TODO(136750944): Eliminate the array copy here.
  13123. - float[] values = input.getFloatArray();
  13124. - int j = 0;
  13125. - for (int i = 0; i < values.length; i++) {
  13126. - values[i] = (values[i] - mean[j]) / stddev[j];
  13127. - j = (j + 1) % numChannels;
  13128. - }
  13129. - TensorBuffer output;
  13130. - if (input.isDynamic()) {
  13131. - output = TensorBufferFloat.createDynamic(DataType.FLOAT32);
  13132. - } else {
  13133. - output = TensorBufferFloat.createFixedSize(shape, DataType.FLOAT32);
  13134. + /**
  13135. + * Applies the defined normalization on given tensor and returns the result.
  13136. + *
  13137. + * <p>Note: {@code input} is possibly the same instance with the output.
  13138. + *
  13139. + * @param input input tensor. It may be the same instance with the output.
  13140. + * @return output tensor.
  13141. + */
  13142. + @Override
  13143. + @NonNull
  13144. + public TensorBuffer apply(@NonNull TensorBuffer input) {
  13145. + if (isIdentityOp) {
  13146. + return input;
  13147. + }
  13148. + int[] shape = input.getShape();
  13149. + SupportPreconditions.checkArgument(
  13150. + numChannels == 1 || (shape.length != 0 && shape[shape.length - 1] == numChannels),
  13151. + "Number of means (stddevs) is not same with number of channels (size of last axis).");
  13152. + // TODO(136750944): Eliminate the array copy here.
  13153. + float[] values = input.getFloatArray();
  13154. + int j = 0;
  13155. + for (int i = 0; i < values.length; i++) {
  13156. + values[i] = (values[i] - mean[j]) / stddev[j];
  13157. + j = (j + 1) % numChannels;
  13158. + }
  13159. + TensorBuffer output;
  13160. + if (input.isDynamic()) {
  13161. + output = TensorBufferFloat.createDynamic(DataType.FLOAT32);
  13162. + } else {
  13163. + output = TensorBufferFloat.createFixedSize(shape, DataType.FLOAT32);
  13164. + }
  13165. + output.loadArray(values, shape);
  13166. + return output;
  13167. }
  13168. - output.loadArray(values, shape);
  13169. - return output;
  13170. - }
  13171. }
  13172. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java
  13173. index 8b3e82aee13ef..84cb856fd4ed9 100644
  13174. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java
  13175. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java
  13176. @@ -33,9 +33,8 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  13177. * as 0.
  13178. */
  13179. public class QuantizeOp extends NormalizeOp implements TensorOperator {
  13180. -
  13181. - public QuantizeOp(float zeroPoint, float scale) {
  13182. - // Quantization: f = (q - z) * s, i.e. q = f / s + z = (f - (-z * s)) / s
  13183. - super(-zeroPoint * scale, scale);
  13184. - }
  13185. + public QuantizeOp(float zeroPoint, float scale) {
  13186. + // Quantization: f = (q - z) * s, i.e. q = f / s + z = (f - (-z * s)) / s
  13187. + super(-zeroPoint * scale, scale);
  13188. + }
  13189. }
  13190. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java
  13191. index 9bee78d139efa..f9b6a1f874bff 100644
  13192. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java
  13193. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java
  13194. @@ -21,67 +21,67 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
  13195. import android.graphics.Bitmap;
  13196. import android.graphics.Bitmap.Config;
  13197. import android.media.Image;
  13198. +
  13199. import org.tensorflow.lite.DataType;
  13200. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  13201. /** Holds a {@link Bitmap} and converts it to other image formats as needed. */
  13202. final class BitmapContainer implements ImageContainer {
  13203. -
  13204. - private final Bitmap bitmap;
  13205. -
  13206. - /**
  13207. - * Creates a {@link BitmapContainer} object with ARGB_8888 {@link Bitmap}.
  13208. - *
  13209. - * @throws IllegalArgumentException if the bitmap configuration is not ARGB_8888
  13210. - */
  13211. - static BitmapContainer create(Bitmap bitmap) {
  13212. - return new BitmapContainer(bitmap);
  13213. - }
  13214. -
  13215. - private BitmapContainer(Bitmap bitmap) {
  13216. - checkNotNull(bitmap, "Cannot load null bitmap.");
  13217. - checkArgument(
  13218. - bitmap.getConfig().equals(Config.ARGB_8888), "Only supports loading ARGB_8888 bitmaps.");
  13219. - this.bitmap = bitmap;
  13220. - }
  13221. -
  13222. - @Override
  13223. - public BitmapContainer clone() {
  13224. - return create(bitmap.copy(bitmap.getConfig(), bitmap.isMutable()));
  13225. - }
  13226. -
  13227. - @Override
  13228. - public Bitmap getBitmap() {
  13229. - // Not making a defensive copy for performance considerations. During image processing,
  13230. - // users may need to set and get the bitmap many times.
  13231. - return bitmap;
  13232. - }
  13233. -
  13234. - @Override
  13235. - public TensorBuffer getTensorBuffer(DataType dataType) {
  13236. - TensorBuffer buffer = TensorBuffer.createDynamic(dataType);
  13237. - ImageConversions.convertBitmapToTensorBuffer(bitmap, buffer);
  13238. - return buffer;
  13239. - }
  13240. -
  13241. - @Override
  13242. - public Image getMediaImage() {
  13243. - throw new UnsupportedOperationException(
  13244. - "Converting from Bitmap to android.media.Image is unsupported.");
  13245. - }
  13246. -
  13247. - @Override
  13248. - public int getWidth() {
  13249. - return bitmap.getWidth();
  13250. - }
  13251. -
  13252. - @Override
  13253. - public int getHeight() {
  13254. - return bitmap.getHeight();
  13255. - }
  13256. -
  13257. - @Override
  13258. - public ColorSpaceType getColorSpaceType() {
  13259. - return ColorSpaceType.fromBitmapConfig(bitmap.getConfig());
  13260. - }
  13261. + private final Bitmap bitmap;
  13262. +
  13263. + /**
  13264. + * Creates a {@link BitmapContainer} object with ARGB_8888 {@link Bitmap}.
  13265. + *
  13266. + * @throws IllegalArgumentException if the bitmap configuration is not ARGB_8888
  13267. + */
  13268. + static BitmapContainer create(Bitmap bitmap) {
  13269. + return new BitmapContainer(bitmap);
  13270. + }
  13271. +
  13272. + private BitmapContainer(Bitmap bitmap) {
  13273. + checkNotNull(bitmap, "Cannot load null bitmap.");
  13274. + checkArgument(bitmap.getConfig().equals(Config.ARGB_8888),
  13275. + "Only supports loading ARGB_8888 bitmaps.");
  13276. + this.bitmap = bitmap;
  13277. + }
  13278. +
  13279. + @Override
  13280. + public BitmapContainer clone() {
  13281. + return create(bitmap.copy(bitmap.getConfig(), bitmap.isMutable()));
  13282. + }
  13283. +
  13284. + @Override
  13285. + public Bitmap getBitmap() {
  13286. + // Not making a defensive copy for performance considerations. During image processing,
  13287. + // users may need to set and get the bitmap many times.
  13288. + return bitmap;
  13289. + }
  13290. +
  13291. + @Override
  13292. + public TensorBuffer getTensorBuffer(DataType dataType) {
  13293. + TensorBuffer buffer = TensorBuffer.createDynamic(dataType);
  13294. + ImageConversions.convertBitmapToTensorBuffer(bitmap, buffer);
  13295. + return buffer;
  13296. + }
  13297. +
  13298. + @Override
  13299. + public Image getMediaImage() {
  13300. + throw new UnsupportedOperationException(
  13301. + "Converting from Bitmap to android.media.Image is unsupported.");
  13302. + }
  13303. +
  13304. + @Override
  13305. + public int getWidth() {
  13306. + return bitmap.getWidth();
  13307. + }
  13308. +
  13309. + @Override
  13310. + public int getHeight() {
  13311. + return bitmap.getHeight();
  13312. + }
  13313. +
  13314. + @Override
  13315. + public ColorSpaceType getColorSpaceType() {
  13316. + return ColorSpaceType.fromBitmapConfig(bitmap.getConfig());
  13317. + }
  13318. }
  13319. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java
  13320. index 8571d6227e136..a2e833b68d6d0 100644
  13321. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java
  13322. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java
  13323. @@ -18,13 +18,15 @@ package org.tensorflow.lite.support.image;
  13324. import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument;
  13325. import android.graphics.RectF;
  13326. +
  13327. +import org.tensorflow.lite.DataType;
  13328. +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  13329. +
  13330. import java.nio.ByteBuffer;
  13331. import java.nio.FloatBuffer;
  13332. import java.util.ArrayList;
  13333. import java.util.Arrays;
  13334. import java.util.List;
  13335. -import org.tensorflow.lite.DataType;
  13336. -import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  13337. /**
  13338. * Helper class for converting values that represents bounding boxes into rectangles.
  13339. @@ -37,207 +39,186 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  13340. * elements in each type is configurable as well.
  13341. */
  13342. public final class BoundingBoxUtil {
  13343. + /** Denotes how a bounding box is represented. */
  13344. + public enum Type {
  13345. + /**
  13346. + * Represents the bounding box by using the combination of boundaries, {left, top, right,
  13347. + * bottom}. The default order is {left, top, right, bottom}. Other orders can be indicated
  13348. + * by an index array.
  13349. + */
  13350. + BOUNDARIES,
  13351. + /**
  13352. + * Represents the bounding box by using the upper_left corner, width and height. The default
  13353. + * order is {upper_left_x, upper_left_y, width, height}. Other orders can be indicated by an
  13354. + * index array.
  13355. + */
  13356. + UPPER_LEFT,
  13357. + /**
  13358. + * Represents the bounding box by using the center of the box, width and height. The default
  13359. + * order is {center_x, center_y, width, height}. Other orders can be indicated by an index
  13360. + * array.
  13361. + */
  13362. + CENTER,
  13363. + }
  13364. +
  13365. + /** Denotes if the coordinates are actual pixels or relative ratios. */
  13366. + public enum CoordinateType {
  13367. + /** The coordinates are relative ratios in range [0, 1]. */
  13368. + RATIO,
  13369. + /** The coordinates are actual pixel values. */
  13370. + PIXEL
  13371. + }
  13372. - /** Denotes how a bounding box is represented. */
  13373. - public enum Type {
  13374. - /**
  13375. - * Represents the bounding box by using the combination of boundaries, {left, top, right,
  13376. - * bottom}. The default order is {left, top, right, bottom}. Other orders can be indicated by an
  13377. - * index array.
  13378. - */
  13379. - BOUNDARIES,
  13380. - /**
  13381. - * Represents the bounding box by using the upper_left corner, width and height. The default
  13382. - * order is {upper_left_x, upper_left_y, width, height}. Other orders can be indicated by an
  13383. - * index array.
  13384. - */
  13385. - UPPER_LEFT,
  13386. /**
  13387. - * Represents the bounding box by using the center of the box, width and height. The default
  13388. - * order is {center_x, center_y, width, height}. Other orders can be indicated by an index
  13389. - * array.
  13390. + * Creates a list of bounding boxes from a {@link TensorBuffer} which represents bounding boxes.
  13391. + *
  13392. + * @param tensor holds the data representing some boxes.
  13393. + * @param valueIndex denotes the order of the elements defined in each bounding box type. An
  13394. + * empty
  13395. + * index array represent the default order of each bounding box type. For example, to denote
  13396. + * the default order of BOUNDARIES, {left, top, right, bottom}, the index should be {0, 1,
  13397. + * 2, 3}. To denote the order {left, right, top, bottom}, the order should be {0, 2, 1, 3}.
  13398. + * <p>The index array can be applied to all bounding box types to adjust the order of their
  13399. + * corresponding underlying elements.
  13400. + * @param boundingBoxAxis specifies the index of the dimension that represents bounding box. The
  13401. + * size of that dimension is required to be 4. Index here starts from 0. For example, if the
  13402. + * tensor has shape 4x10, the axis for bounding boxes is likely to be 0. Negative axis is
  13403. + * also supported: -1 gives the last axis and -2 gives the second, .etc. theFor shape 10x4, the
  13404. + * axis is likely to be 1 (or -1, equivalently).
  13405. + * @param type defines how values should be converted into boxes. See {@link Type}
  13406. + * @param coordinateType defines how values are interpreted to coordinates. See {@link
  13407. + * CoordinateType}
  13408. + * @param height the height of the image which the boxes belong to. Only has effects when {@code
  13409. + * coordinateType} is {@link CoordinateType#RATIO}
  13410. + * @param width the width of the image which the boxes belong to. Only has effects when {@code
  13411. + * coordinateType} is {@link CoordinateType#RATIO}
  13412. + * @return A list of bounding boxes that the {@code tensor} represents. All dimensions except
  13413. + * {@code boundingBoxAxis} will be collapsed with order kept. For example, given {@code
  13414. + * tensor} with shape {1, 4, 10, 2} and {@code boundingBoxAxis = 1}, The result will be a
  13415. + * list of 20 bounding boxes.
  13416. + * @throws IllegalArgumentException if size of bounding box dimension (set by {@code
  13417. + * boundingBoxAxis}) is not 4.
  13418. + * @throws IllegalArgumentException if {@code boundingBoxAxis} is not in {@code (-(D+1), D)}
  13419. + * where
  13420. + * {@code D} is the number of dimensions of the {@code tensor}.
  13421. + * @throws IllegalArgumentException if {@code tensor} has data type other than {@link
  13422. + * DataType#FLOAT32}.
  13423. */
  13424. - CENTER,
  13425. - }
  13426. -
  13427. - /** Denotes if the coordinates are actual pixels or relative ratios. */
  13428. - public enum CoordinateType {
  13429. - /** The coordinates are relative ratios in range [0, 1]. */
  13430. - RATIO,
  13431. - /** The coordinates are actual pixel values. */
  13432. - PIXEL
  13433. - }
  13434. -
  13435. - /**
  13436. - * Creates a list of bounding boxes from a {@link TensorBuffer} which represents bounding boxes.
  13437. - *
  13438. - * @param tensor holds the data representing some boxes.
  13439. - * @param valueIndex denotes the order of the elements defined in each bounding box type. An empty
  13440. - * index array represent the default order of each bounding box type. For example, to denote
  13441. - * the default order of BOUNDARIES, {left, top, right, bottom}, the index should be {0, 1, 2,
  13442. - * 3}. To denote the order {left, right, top, bottom}, the order should be {0, 2, 1, 3}.
  13443. - * <p>The index array can be applied to all bounding box types to adjust the order of their
  13444. - * corresponding underlying elements.
  13445. - * @param boundingBoxAxis specifies the index of the dimension that represents bounding box. The
  13446. - * size of that dimension is required to be 4. Index here starts from 0. For example, if the
  13447. - * tensor has shape 4x10, the axis for bounding boxes is likely to be 0. Negative axis is also
  13448. - * supported: -1 gives the last axis and -2 gives the second, .etc. theFor shape 10x4, the
  13449. - * axis is likely to be 1 (or -1, equivalently).
  13450. - * @param type defines how values should be converted into boxes. See {@link Type}
  13451. - * @param coordinateType defines how values are interpreted to coordinates. See {@link
  13452. - * CoordinateType}
  13453. - * @param height the height of the image which the boxes belong to. Only has effects when {@code
  13454. - * coordinateType} is {@link CoordinateType#RATIO}
  13455. - * @param width the width of the image which the boxes belong to. Only has effects when {@code
  13456. - * coordinateType} is {@link CoordinateType#RATIO}
  13457. - * @return A list of bounding boxes that the {@code tensor} represents. All dimensions except
  13458. - * {@code boundingBoxAxis} will be collapsed with order kept. For example, given {@code
  13459. - * tensor} with shape {1, 4, 10, 2} and {@code boundingBoxAxis = 1}, The result will be a list
  13460. - * of 20 bounding boxes.
  13461. - * @throws IllegalArgumentException if size of bounding box dimension (set by {@code
  13462. - * boundingBoxAxis}) is not 4.
  13463. - * @throws IllegalArgumentException if {@code boundingBoxAxis} is not in {@code (-(D+1), D)} where
  13464. - * {@code D} is the number of dimensions of the {@code tensor}.
  13465. - * @throws IllegalArgumentException if {@code tensor} has data type other than {@link
  13466. - * DataType#FLOAT32}.
  13467. - */
  13468. - public static List<RectF> convert(
  13469. - TensorBuffer tensor,
  13470. - int[] valueIndex,
  13471. - int boundingBoxAxis,
  13472. - Type type,
  13473. - CoordinateType coordinateType,
  13474. - int height,
  13475. - int width) {
  13476. - int[] shape = tensor.getShape();
  13477. - checkArgument(
  13478. - boundingBoxAxis >= -shape.length && boundingBoxAxis < shape.length,
  13479. - String.format(
  13480. - "Axis %d is not in range (-(D+1), D), where D is the number of dimensions of input"
  13481. - + " tensor (shape=%s)",
  13482. - boundingBoxAxis, Arrays.toString(shape)));
  13483. - if (boundingBoxAxis < 0) {
  13484. - boundingBoxAxis = shape.length + boundingBoxAxis;
  13485. - }
  13486. - checkArgument(
  13487. - shape[boundingBoxAxis] == 4,
  13488. - String.format(
  13489. - "Size of bounding box dimension %d is not 4. Got %d in shape %s",
  13490. - boundingBoxAxis, shape[boundingBoxAxis], Arrays.toString(shape)));
  13491. - checkArgument(
  13492. - valueIndex.length == 4,
  13493. - String.format(
  13494. - "Bounding box index array length %d is not 4. Got index array %s",
  13495. - valueIndex.length, Arrays.toString(valueIndex)));
  13496. - checkArgument(
  13497. - tensor.getDataType() == DataType.FLOAT32,
  13498. - "Bounding Boxes only create from FLOAT32 buffers. Got: " + tensor.getDataType().name());
  13499. - List<RectF> boundingBoxList = new ArrayList<>();
  13500. - // Collapse dimensions to {a, 4, b}. So each bounding box could be represent as (i, j), and its
  13501. - // four values are (i, k, j), where 0 <= k < 4. We can compute the 4 flattened index by
  13502. - // i * 4b + k * b + j.
  13503. - int a = 1;
  13504. - for (int i = 0; i < boundingBoxAxis; i++) {
  13505. - a *= shape[i];
  13506. + public static List<RectF> convert(TensorBuffer tensor, int[] valueIndex, int boundingBoxAxis,
  13507. + Type type, CoordinateType coordinateType, int height, int width) {
  13508. + int[] shape = tensor.getShape();
  13509. + checkArgument(boundingBoxAxis >= -shape.length && boundingBoxAxis < shape.length,
  13510. + String.format(
  13511. + "Axis %d is not in range (-(D+1), D), where D is the number of dimensions of input"
  13512. + + " tensor (shape=%s)",
  13513. + boundingBoxAxis, Arrays.toString(shape)));
  13514. + if (boundingBoxAxis < 0) {
  13515. + boundingBoxAxis = shape.length + boundingBoxAxis;
  13516. + }
  13517. + checkArgument(shape[boundingBoxAxis] == 4,
  13518. + String.format("Size of bounding box dimension %d is not 4. Got %d in shape %s",
  13519. + boundingBoxAxis, shape[boundingBoxAxis], Arrays.toString(shape)));
  13520. + checkArgument(valueIndex.length == 4,
  13521. + String.format("Bounding box index array length %d is not 4. Got index array %s",
  13522. + valueIndex.length, Arrays.toString(valueIndex)));
  13523. + checkArgument(tensor.getDataType() == DataType.FLOAT32,
  13524. + "Bounding Boxes only create from FLOAT32 buffers. Got: "
  13525. + + tensor.getDataType().name());
  13526. + List<RectF> boundingBoxList = new ArrayList<>();
  13527. + // Collapse dimensions to {a, 4, b}. So each bounding box could be represent as (i, j), and
  13528. + // its four values are (i, k, j), where 0 <= k < 4. We can compute the 4 flattened index by
  13529. + // i * 4b + k * b + j.
  13530. + int a = 1;
  13531. + for (int i = 0; i < boundingBoxAxis; i++) {
  13532. + a *= shape[i];
  13533. + }
  13534. + int b = 1;
  13535. + for (int i = boundingBoxAxis + 1; i < shape.length; i++) {
  13536. + b *= shape[i];
  13537. + }
  13538. + float[] values = new float[4];
  13539. + ByteBuffer byteBuffer = tensor.getBuffer();
  13540. + byteBuffer.rewind();
  13541. + FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
  13542. + for (int i = 0; i < a; i++) {
  13543. + for (int j = 0; j < b; j++) {
  13544. + for (int k = 0; k < 4; k++) {
  13545. + values[k] = floatBuffer.get((i * 4 + k) * b + j);
  13546. + }
  13547. + boundingBoxList.add(convertOneBoundingBox(
  13548. + values, valueIndex, type, coordinateType, height, width));
  13549. + }
  13550. + }
  13551. + byteBuffer.rewind();
  13552. + return boundingBoxList;
  13553. }
  13554. - int b = 1;
  13555. - for (int i = boundingBoxAxis + 1; i < shape.length; i++) {
  13556. - b *= shape[i];
  13557. +
  13558. + private static RectF convertOneBoundingBox(float[] values, int[] valueIndex, Type type,
  13559. + CoordinateType coordinateType, int height, int width) {
  13560. + float[] orderedValues = new float[4];
  13561. + for (int i = 0; i < 4; i++) {
  13562. + orderedValues[i] = values[valueIndex[i]];
  13563. + }
  13564. + return convertOneBoundingBox(orderedValues, type, coordinateType, height, width);
  13565. }
  13566. - float[] values = new float[4];
  13567. - ByteBuffer byteBuffer = tensor.getBuffer();
  13568. - byteBuffer.rewind();
  13569. - FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
  13570. - for (int i = 0; i < a; i++) {
  13571. - for (int j = 0; j < b; j++) {
  13572. - for (int k = 0; k < 4; k++) {
  13573. - values[k] = floatBuffer.get((i * 4 + k) * b + j);
  13574. +
  13575. + private static RectF convertOneBoundingBox(
  13576. + float[] values, Type type, CoordinateType coordinateType, int height, int width) {
  13577. + switch (type) {
  13578. + case BOUNDARIES:
  13579. + return convertFromBoundaries(values, coordinateType, height, width);
  13580. + case UPPER_LEFT:
  13581. + return convertFromUpperLeft(values, coordinateType, height, width);
  13582. + case CENTER:
  13583. + return convertFromCenter(values, coordinateType, height, width);
  13584. }
  13585. - boundingBoxList.add(
  13586. - convertOneBoundingBox(values, valueIndex, type, coordinateType, height, width));
  13587. - }
  13588. + throw new IllegalArgumentException("Cannot recognize BoundingBox.Type " + type);
  13589. }
  13590. - byteBuffer.rewind();
  13591. - return boundingBoxList;
  13592. - }
  13593. -
  13594. - private static RectF convertOneBoundingBox(
  13595. - float[] values,
  13596. - int[] valueIndex,
  13597. - Type type,
  13598. - CoordinateType coordinateType,
  13599. - int height,
  13600. - int width) {
  13601. - float[] orderedValues = new float[4];
  13602. - for (int i = 0; i < 4; i++) {
  13603. - orderedValues[i] = values[valueIndex[i]];
  13604. +
  13605. + private static RectF convertFromBoundaries(
  13606. + float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
  13607. + float left = values[0];
  13608. + float top = values[1];
  13609. + float right = values[2];
  13610. + float bottom = values[3];
  13611. + return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
  13612. + }
  13613. +
  13614. + private static RectF convertFromUpperLeft(
  13615. + float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
  13616. + float left = values[0];
  13617. + float top = values[1];
  13618. + float right = values[0] + values[2];
  13619. + float bottom = values[1] + values[3];
  13620. + return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
  13621. }
  13622. - return convertOneBoundingBox(orderedValues, type, coordinateType, height, width);
  13623. - }
  13624. -
  13625. - private static RectF convertOneBoundingBox(
  13626. - float[] values, Type type, CoordinateType coordinateType, int height, int width) {
  13627. - switch (type) {
  13628. - case BOUNDARIES:
  13629. - return convertFromBoundaries(values, coordinateType, height, width);
  13630. - case UPPER_LEFT:
  13631. - return convertFromUpperLeft(values, coordinateType, height, width);
  13632. - case CENTER:
  13633. - return convertFromCenter(values, coordinateType, height, width);
  13634. +
  13635. + private static RectF convertFromCenter(
  13636. + float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
  13637. + float centerX = values[0];
  13638. + float centerY = values[1];
  13639. + float w = values[2];
  13640. + float h = values[3];
  13641. +
  13642. + float left = centerX - w / 2;
  13643. + float top = centerY - h / 2;
  13644. + float right = centerX + w / 2;
  13645. + float bottom = centerY + h / 2;
  13646. + return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
  13647. }
  13648. - throw new IllegalArgumentException("Cannot recognize BoundingBox.Type " + type);
  13649. - }
  13650. -
  13651. - private static RectF convertFromBoundaries(
  13652. - float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
  13653. - float left = values[0];
  13654. - float top = values[1];
  13655. - float right = values[2];
  13656. - float bottom = values[3];
  13657. - return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
  13658. - }
  13659. -
  13660. - private static RectF convertFromUpperLeft(
  13661. - float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
  13662. - float left = values[0];
  13663. - float top = values[1];
  13664. - float right = values[0] + values[2];
  13665. - float bottom = values[1] + values[3];
  13666. - return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
  13667. - }
  13668. -
  13669. - private static RectF convertFromCenter(
  13670. - float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
  13671. - float centerX = values[0];
  13672. - float centerY = values[1];
  13673. - float w = values[2];
  13674. - float h = values[3];
  13675. -
  13676. - float left = centerX - w / 2;
  13677. - float top = centerY - h / 2;
  13678. - float right = centerX + w / 2;
  13679. - float bottom = centerY + h / 2;
  13680. - return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
  13681. - }
  13682. -
  13683. - private static RectF getRectF(
  13684. - float left,
  13685. - float top,
  13686. - float right,
  13687. - float bottom,
  13688. - int imageHeight,
  13689. - int imageWidth,
  13690. - CoordinateType coordinateType) {
  13691. - if (coordinateType == CoordinateType.PIXEL) {
  13692. - return new RectF(left, top, right, bottom);
  13693. - } else if (coordinateType == CoordinateType.RATIO) {
  13694. - return new RectF(
  13695. - left * imageWidth, top * imageHeight, right * imageWidth, bottom * imageHeight);
  13696. - } else {
  13697. - throw new IllegalArgumentException("Cannot convert coordinate type " + coordinateType);
  13698. +
  13699. + private static RectF getRectF(float left, float top, float right, float bottom, int imageHeight,
  13700. + int imageWidth, CoordinateType coordinateType) {
  13701. + if (coordinateType == CoordinateType.PIXEL) {
  13702. + return new RectF(left, top, right, bottom);
  13703. + } else if (coordinateType == CoordinateType.RATIO) {
  13704. + return new RectF(
  13705. + left * imageWidth, top * imageHeight, right * imageWidth, bottom * imageHeight);
  13706. + } else {
  13707. + throw new IllegalArgumentException("Cannot convert coordinate type " + coordinateType);
  13708. + }
  13709. }
  13710. - }
  13711. - // Private constructor to prevent initialization.
  13712. - private BoundingBoxUtil() {}
  13713. + // Private constructor to prevent initialization.
  13714. + private BoundingBoxUtil() {}
  13715. }
  13716. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java
  13717. index 457bcf1da1de3..716cacdf7bf51 100644
  13718. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java
  13719. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java
  13720. @@ -20,354 +20,351 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
  13721. import android.graphics.Bitmap;
  13722. import android.graphics.Bitmap.Config;
  13723. import android.graphics.ImageFormat;
  13724. -import java.util.Arrays;
  13725. +
  13726. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  13727. +import java.util.Arrays;
  13728. +
  13729. /** Represents the type of color space of an image. */
  13730. public enum ColorSpaceType {
  13731. - /** Each pixel has red, green, and blue color components. */
  13732. - RGB(0) {
  13733. -
  13734. - // The channel axis should always be 3 for RGB images.
  13735. - private static final int CHANNEL_VALUE = 3;
  13736. -
  13737. - @Override
  13738. - Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
  13739. - return ImageConversions.convertRgbTensorBufferToBitmap(buffer);
  13740. + /** Each pixel has red, green, and blue color components. */
  13741. + RGB(0) {
  13742. + // The channel axis should always be 3 for RGB images.
  13743. + private static final int CHANNEL_VALUE = 3;
  13744. +
  13745. + @Override
  13746. + Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
  13747. + return ImageConversions.convertRgbTensorBufferToBitmap(buffer);
  13748. + }
  13749. +
  13750. + @Override
  13751. + int getChannelValue() {
  13752. + return CHANNEL_VALUE;
  13753. + }
  13754. +
  13755. + @Override
  13756. + int[] getNormalizedShape(int[] shape) {
  13757. + switch (shape.length) {
  13758. + // The shape is in (h, w, c) format.
  13759. + case 3:
  13760. + return insertValue(shape, BATCH_DIM, BATCH_VALUE);
  13761. + case 4:
  13762. + return shape;
  13763. + default:
  13764. + throw new IllegalArgumentException(getShapeInfoMessage()
  13765. + + "The provided image shape is " + Arrays.toString(shape));
  13766. + }
  13767. + }
  13768. +
  13769. + @Override
  13770. + int getNumElements(int height, int width) {
  13771. + return height * width * CHANNEL_VALUE;
  13772. + }
  13773. +
  13774. + @Override
  13775. + String getShapeInfoMessage() {
  13776. + return "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
  13777. + + " representing R, G, B in order. ";
  13778. + }
  13779. +
  13780. + @Override
  13781. + Config toBitmapConfig() {
  13782. + return Config.ARGB_8888;
  13783. + }
  13784. + },
  13785. +
  13786. + /** Each pixel is a single element representing only the amount of light. */
  13787. + GRAYSCALE(1) {
  13788. + // The channel axis should always be 1 for grayscale images.
  13789. + private static final int CHANNEL_VALUE = 1;
  13790. +
  13791. + @Override
  13792. + Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
  13793. + return ImageConversions.convertGrayscaleTensorBufferToBitmap(buffer);
  13794. + }
  13795. +
  13796. + @Override
  13797. + int getChannelValue() {
  13798. + return CHANNEL_VALUE;
  13799. + }
  13800. +
  13801. + @Override
  13802. + int[] getNormalizedShape(int[] shape) {
  13803. + switch (shape.length) {
  13804. + // The shape is in (h, w) format.
  13805. + case 2:
  13806. + int[] shapeWithBatch = insertValue(shape, BATCH_DIM, BATCH_VALUE);
  13807. + return insertValue(shapeWithBatch, CHANNEL_DIM, CHANNEL_VALUE);
  13808. + case 4:
  13809. + return shape;
  13810. + default:
  13811. + // (1, h, w) and (h, w, 1) are potential grayscale image shapes. However, since
  13812. + // they both have three dimensions, it will require extra info to differentiate
  13813. + // between them. Since we haven't encountered real use cases of these two
  13814. + // shapes, they are not supported at this moment to avoid confusion. We may want
  13815. + // to revisit it in the future.
  13816. + throw new IllegalArgumentException(getShapeInfoMessage()
  13817. + + "The provided image shape is " + Arrays.toString(shape));
  13818. + }
  13819. + }
  13820. +
  13821. + @Override
  13822. + int getNumElements(int height, int width) {
  13823. + return height * width;
  13824. + }
  13825. +
  13826. + @Override
  13827. + String getShapeInfoMessage() {
  13828. + return "The shape of a grayscale image should be (h, w) or (1, h, w, 1). ";
  13829. + }
  13830. +
  13831. + @Override
  13832. + Config toBitmapConfig() {
  13833. + return Config.ALPHA_8;
  13834. + }
  13835. + },
  13836. +
  13837. + /** YUV420sp format, encoded as "YYYYYYYY UVUV". */
  13838. + NV12(2) {
  13839. + @Override
  13840. + int getNumElements(int height, int width) {
  13841. + return getYuv420NumElements(height, width);
  13842. + }
  13843. + },
  13844. +
  13845. + /**
  13846. + * YUV420sp format, encoded as "YYYYYYYY VUVU", the standard picture format on Android Camera1
  13847. + * preview.
  13848. + */
  13849. + NV21(3) {
  13850. + @Override
  13851. + int getNumElements(int height, int width) {
  13852. + return getYuv420NumElements(height, width);
  13853. + }
  13854. + },
  13855. +
  13856. + /** YUV420p format, encoded as "YYYYYYYY VV UU". */
  13857. + YV12(4) {
  13858. + @Override
  13859. + int getNumElements(int height, int width) {
  13860. + return getYuv420NumElements(height, width);
  13861. + }
  13862. + },
  13863. +
  13864. + /** YUV420p format, encoded as "YYYYYYYY UU VV". */
  13865. + YV21(5) {
  13866. + @Override
  13867. + int getNumElements(int height, int width) {
  13868. + return getYuv420NumElements(height, width);
  13869. + }
  13870. + },
  13871. +
  13872. + /**
  13873. + * YUV420 format corresponding to {@link android.graphics.ImageFormat#YUV_420_888}. The actual
  13874. + * encoding format (i.e. NV12 / Nv21 / YV12 / YV21) depends on the implementation of the image.
  13875. + *
  13876. + * <p>Use this format only when you load an {@link android.media.Image}.
  13877. + */
  13878. + YUV_420_888(6) {
  13879. + @Override
  13880. + int getNumElements(int height, int width) {
  13881. + return getYuv420NumElements(height, width);
  13882. + }
  13883. + };
  13884. +
  13885. + private static final int BATCH_DIM = 0; // The first element of the normalizaed shape.
  13886. + private static final int BATCH_VALUE = 1; // The batch axis should always be one.
  13887. + private static final int HEIGHT_DIM = 1; // The second element of the normalizaed shape.
  13888. + private static final int WIDTH_DIM = 2; // The third element of the normalizaed shape.
  13889. + private static final int CHANNEL_DIM = 3; // The fourth element of the normalizaed shape.
  13890. + private final int value;
  13891. +
  13892. + ColorSpaceType(int value) {
  13893. + this.value = value;
  13894. }
  13895. - @Override
  13896. - int getChannelValue() {
  13897. - return CHANNEL_VALUE;
  13898. + /**
  13899. + * Converts a bitmap configuration into the corresponding color space type.
  13900. + *
  13901. + * @throws IllegalArgumentException if the config is unsupported
  13902. + */
  13903. + static ColorSpaceType fromBitmapConfig(Config config) {
  13904. + switch (config) {
  13905. + case ARGB_8888:
  13906. + return ColorSpaceType.RGB;
  13907. + case ALPHA_8:
  13908. + return ColorSpaceType.GRAYSCALE;
  13909. + default:
  13910. + throw new IllegalArgumentException(
  13911. + "Bitmap configuration: " + config + ", is not supported yet.");
  13912. + }
  13913. }
  13914. - @Override
  13915. - int[] getNormalizedShape(int[] shape) {
  13916. - switch (shape.length) {
  13917. - // The shape is in (h, w, c) format.
  13918. - case 3:
  13919. - return insertValue(shape, BATCH_DIM, BATCH_VALUE);
  13920. - case 4:
  13921. - return shape;
  13922. - default:
  13923. - throw new IllegalArgumentException(
  13924. - getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape));
  13925. - }
  13926. + /**
  13927. + * Converts an {@link ImageFormat} value into the corresponding color space type.
  13928. + *
  13929. + * @throws IllegalArgumentException if the config is unsupported
  13930. + */
  13931. + static ColorSpaceType fromImageFormat(int imageFormat) {
  13932. + switch (imageFormat) {
  13933. + case ImageFormat.NV21:
  13934. + return ColorSpaceType.NV21;
  13935. + case ImageFormat.YV12:
  13936. + return ColorSpaceType.YV12;
  13937. + case ImageFormat.YUV_420_888:
  13938. + return ColorSpaceType.YUV_420_888;
  13939. + default:
  13940. + throw new IllegalArgumentException(
  13941. + "ImageFormat: " + imageFormat + ", is not supported yet.");
  13942. + }
  13943. }
  13944. - @Override
  13945. - int getNumElements(int height, int width) {
  13946. - return height * width * CHANNEL_VALUE;
  13947. + public int getValue() {
  13948. + return value;
  13949. }
  13950. - @Override
  13951. - String getShapeInfoMessage() {
  13952. - return "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
  13953. - + " representing R, G, B in order. ";
  13954. + /**
  13955. + * Verifies if the given shape matches the color space type.
  13956. + *
  13957. + * @throws IllegalArgumentException if {@code shape} does not match the color space type
  13958. + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
  13959. + */
  13960. + void assertShape(int[] shape) {
  13961. + assertRgbOrGrayScale("assertShape()");
  13962. +
  13963. + int[] normalizedShape = getNormalizedShape(shape);
  13964. + checkArgument(isValidNormalizedShape(normalizedShape),
  13965. + getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape));
  13966. }
  13967. - @Override
  13968. - Config toBitmapConfig() {
  13969. - return Config.ARGB_8888;
  13970. + /**
  13971. + * Verifies if the given {@code numElements} in an image buffer matches {@code height} / {@code
  13972. + * width} under this color space type. For example, the {@code numElements} of an RGB image of
  13973. + * 30 x 20 should be {@code 30 * 20 * 3 = 1800}; the {@code numElements} of a NV21 image of 30 x
  13974. + * 20 should be {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}.
  13975. + *
  13976. + * @throws IllegalArgumentException if {@code shape} does not match the color space type
  13977. + */
  13978. + void assertNumElements(int numElements, int height, int width) {
  13979. + checkArgument(numElements >= getNumElements(height, width),
  13980. + String.format(
  13981. + "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
  13982. + + " expected number of elements should be at least %d.",
  13983. + numElements, this.name(), height, width, getNumElements(height, width)));
  13984. }
  13985. - },
  13986. -
  13987. - /** Each pixel is a single element representing only the amount of light. */
  13988. - GRAYSCALE(1) {
  13989. -
  13990. - // The channel axis should always be 1 for grayscale images.
  13991. - private static final int CHANNEL_VALUE = 1;
  13992. - @Override
  13993. + /**
  13994. + * Converts a {@link TensorBuffer} that represents an image to a Bitmap with the color space
  13995. + * type.
  13996. + *
  13997. + * @throws IllegalArgumentException if the shape of buffer does not match the color space type,
  13998. + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
  13999. + */
  14000. Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
  14001. - return ImageConversions.convertGrayscaleTensorBufferToBitmap(buffer);
  14002. + throw new UnsupportedOperationException(
  14003. + "convertTensorBufferToBitmap() is unsupported for the color space type "
  14004. + + this.name());
  14005. }
  14006. - @Override
  14007. - int getChannelValue() {
  14008. - return CHANNEL_VALUE;
  14009. + /**
  14010. + * Returns the width of the given shape corresponding to the color space type.
  14011. + *
  14012. + * @throws IllegalArgumentException if {@code shape} does not match the color space type
  14013. + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
  14014. + */
  14015. + int getWidth(int[] shape) {
  14016. + assertRgbOrGrayScale("getWidth()");
  14017. + assertShape(shape);
  14018. + return getNormalizedShape(shape)[WIDTH_DIM];
  14019. }
  14020. - @Override
  14021. - int[] getNormalizedShape(int[] shape) {
  14022. - switch (shape.length) {
  14023. - // The shape is in (h, w) format.
  14024. - case 2:
  14025. - int[] shapeWithBatch = insertValue(shape, BATCH_DIM, BATCH_VALUE);
  14026. - return insertValue(shapeWithBatch, CHANNEL_DIM, CHANNEL_VALUE);
  14027. - case 4:
  14028. - return shape;
  14029. - default:
  14030. - // (1, h, w) and (h, w, 1) are potential grayscale image shapes. However, since they
  14031. - // both have three dimensions, it will require extra info to differentiate between them.
  14032. - // Since we haven't encountered real use cases of these two shapes, they are not supported
  14033. - // at this moment to avoid confusion. We may want to revisit it in the future.
  14034. - throw new IllegalArgumentException(
  14035. - getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape));
  14036. - }
  14037. + /**
  14038. + * Returns the height of the given shape corresponding to the color space type.
  14039. + *
  14040. + * @throws IllegalArgumentException if {@code shape} does not match the color space type
  14041. + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
  14042. + */
  14043. + int getHeight(int[] shape) {
  14044. + assertRgbOrGrayScale("getHeight()");
  14045. + assertShape(shape);
  14046. + return getNormalizedShape(shape)[HEIGHT_DIM];
  14047. }
  14048. - @Override
  14049. - int getNumElements(int height, int width) {
  14050. - return height * width;
  14051. + /**
  14052. + * Returns the channel value corresponding to the color space type.
  14053. + *
  14054. + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
  14055. + */
  14056. + int getChannelValue() {
  14057. + throw new UnsupportedOperationException(
  14058. + "getChannelValue() is unsupported for the color space type " + this.name());
  14059. + }
  14060. + /**
  14061. + * Gets the normalized shape in the form of (1, h, w, c). Sometimes, a given shape may not have
  14062. + * batch or channel axis.
  14063. + *
  14064. + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
  14065. + */
  14066. + int[] getNormalizedShape(int[] shape) {
  14067. + throw new UnsupportedOperationException(
  14068. + "getNormalizedShape() is unsupported for the color space type " + this.name());
  14069. }
  14070. - @Override
  14071. + /**
  14072. + * Returns the shape information corresponding to the color space type.
  14073. + *
  14074. + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
  14075. + */
  14076. String getShapeInfoMessage() {
  14077. - return "The shape of a grayscale image should be (h, w) or (1, h, w, 1). ";
  14078. + throw new UnsupportedOperationException(
  14079. + "getShapeInfoMessage() is unsupported for the color space type " + this.name());
  14080. }
  14081. - @Override
  14082. + /**
  14083. + * Converts the color space type to the corresponding bitmap config.
  14084. + *
  14085. + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
  14086. + */
  14087. Config toBitmapConfig() {
  14088. - return Config.ALPHA_8;
  14089. + throw new UnsupportedOperationException(
  14090. + "toBitmapConfig() is unsupported for the color space type " + this.name());
  14091. }
  14092. - },
  14093. - /** YUV420sp format, encoded as "YYYYYYYY UVUV". */
  14094. - NV12(2) {
  14095. - @Override
  14096. - int getNumElements(int height, int width) {
  14097. - return getYuv420NumElements(height, width);
  14098. - }
  14099. - },
  14100. -
  14101. - /**
  14102. - * YUV420sp format, encoded as "YYYYYYYY VUVU", the standard picture format on Android Camera1
  14103. - * preview.
  14104. - */
  14105. - NV21(3) {
  14106. - @Override
  14107. - int getNumElements(int height, int width) {
  14108. - return getYuv420NumElements(height, width);
  14109. - }
  14110. - },
  14111. + /**
  14112. + * Gets the number of elements given the height and width of an image. For example, the number
  14113. + * of elements of an RGB image of 30 x 20 is {@code 30 * 20 * 3 = 1800}; the number of elements
  14114. + * of a NV21 image of 30 x 20 is {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}.
  14115. + */
  14116. + abstract int getNumElements(int height, int width);
  14117. - /** YUV420p format, encoded as "YYYYYYYY VV UU". */
  14118. - YV12(4) {
  14119. - @Override
  14120. - int getNumElements(int height, int width) {
  14121. - return getYuv420NumElements(height, width);
  14122. + private static int getYuv420NumElements(int height, int width) {
  14123. + // Height and width of U/V planes are half of the Y plane.
  14124. + return height * width + ((height + 1) / 2) * ((width + 1) / 2) * 2;
  14125. }
  14126. - },
  14127. - /** YUV420p format, encoded as "YYYYYYYY UU VV". */
  14128. - YV21(5) {
  14129. - @Override
  14130. - int getNumElements(int height, int width) {
  14131. - return getYuv420NumElements(height, width);
  14132. + /** Inserts a value at the specified position and return the new array. */
  14133. + private static int[] insertValue(int[] array, int pos, int value) {
  14134. + int[] newArray = new int[array.length + 1];
  14135. + for (int i = 0; i < pos; i++) {
  14136. + newArray[i] = array[i];
  14137. + }
  14138. + newArray[pos] = value;
  14139. + for (int i = pos + 1; i < newArray.length; i++) {
  14140. + newArray[i] = array[i - 1];
  14141. + }
  14142. + return newArray;
  14143. }
  14144. - },
  14145. -
  14146. - /**
  14147. - * YUV420 format corresponding to {@link android.graphics.ImageFormat#YUV_420_888}. The actual
  14148. - * encoding format (i.e. NV12 / Nv21 / YV12 / YV21) depends on the implementation of the image.
  14149. - *
  14150. - * <p>Use this format only when you load an {@link android.media.Image}.
  14151. - */
  14152. - YUV_420_888(6) {
  14153. - @Override
  14154. - int getNumElements(int height, int width) {
  14155. - return getYuv420NumElements(height, width);
  14156. - }
  14157. - };
  14158. -
  14159. - private static final int BATCH_DIM = 0; // The first element of the normalizaed shape.
  14160. - private static final int BATCH_VALUE = 1; // The batch axis should always be one.
  14161. - private static final int HEIGHT_DIM = 1; // The second element of the normalizaed shape.
  14162. - private static final int WIDTH_DIM = 2; // The third element of the normalizaed shape.
  14163. - private static final int CHANNEL_DIM = 3; // The fourth element of the normalizaed shape.
  14164. - private final int value;
  14165. -
  14166. - ColorSpaceType(int value) {
  14167. - this.value = value;
  14168. - }
  14169. -
  14170. - /**
  14171. - * Converts a bitmap configuration into the corresponding color space type.
  14172. - *
  14173. - * @throws IllegalArgumentException if the config is unsupported
  14174. - */
  14175. - static ColorSpaceType fromBitmapConfig(Config config) {
  14176. - switch (config) {
  14177. - case ARGB_8888:
  14178. - return ColorSpaceType.RGB;
  14179. - case ALPHA_8:
  14180. - return ColorSpaceType.GRAYSCALE;
  14181. - default:
  14182. - throw new IllegalArgumentException(
  14183. - "Bitmap configuration: " + config + ", is not supported yet.");
  14184. - }
  14185. - }
  14186. -
  14187. - /**
  14188. - * Converts an {@link ImageFormat} value into the corresponding color space type.
  14189. - *
  14190. - * @throws IllegalArgumentException if the config is unsupported
  14191. - */
  14192. - static ColorSpaceType fromImageFormat(int imageFormat) {
  14193. - switch (imageFormat) {
  14194. - case ImageFormat.NV21:
  14195. - return ColorSpaceType.NV21;
  14196. - case ImageFormat.YV12:
  14197. - return ColorSpaceType.YV12;
  14198. - case ImageFormat.YUV_420_888:
  14199. - return ColorSpaceType.YUV_420_888;
  14200. - default:
  14201. - throw new IllegalArgumentException(
  14202. - "ImageFormat: " + imageFormat + ", is not supported yet.");
  14203. - }
  14204. - }
  14205. -
  14206. - public int getValue() {
  14207. - return value;
  14208. - }
  14209. -
  14210. - /**
  14211. - * Verifies if the given shape matches the color space type.
  14212. - *
  14213. - * @throws IllegalArgumentException if {@code shape} does not match the color space type
  14214. - * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
  14215. - */
  14216. - void assertShape(int[] shape) {
  14217. - assertRgbOrGrayScale("assertShape()");
  14218. -
  14219. - int[] normalizedShape = getNormalizedShape(shape);
  14220. - checkArgument(
  14221. - isValidNormalizedShape(normalizedShape),
  14222. - getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape));
  14223. - }
  14224. -
  14225. - /**
  14226. - * Verifies if the given {@code numElements} in an image buffer matches {@code height} / {@code
  14227. - * width} under this color space type. For example, the {@code numElements} of an RGB image of 30
  14228. - * x 20 should be {@code 30 * 20 * 3 = 1800}; the {@code numElements} of a NV21 image of 30 x 20
  14229. - * should be {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}.
  14230. - *
  14231. - * @throws IllegalArgumentException if {@code shape} does not match the color space type
  14232. - */
  14233. - void assertNumElements(int numElements, int height, int width) {
  14234. - checkArgument(
  14235. - numElements >= getNumElements(height, width),
  14236. - String.format(
  14237. - "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
  14238. - + " expected number of elements should be at least %d.",
  14239. - numElements, this.name(), height, width, getNumElements(height, width)));
  14240. - }
  14241. -
  14242. - /**
  14243. - * Converts a {@link TensorBuffer} that represents an image to a Bitmap with the color space type.
  14244. - *
  14245. - * @throws IllegalArgumentException if the shape of buffer does not match the color space type,
  14246. - * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
  14247. - */
  14248. - Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
  14249. - throw new UnsupportedOperationException(
  14250. - "convertTensorBufferToBitmap() is unsupported for the color space type " + this.name());
  14251. - }
  14252. -
  14253. - /**
  14254. - * Returns the width of the given shape corresponding to the color space type.
  14255. - *
  14256. - * @throws IllegalArgumentException if {@code shape} does not match the color space type
  14257. - * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
  14258. - */
  14259. - int getWidth(int[] shape) {
  14260. - assertRgbOrGrayScale("getWidth()");
  14261. - assertShape(shape);
  14262. - return getNormalizedShape(shape)[WIDTH_DIM];
  14263. - }
  14264. -
  14265. - /**
  14266. - * Returns the height of the given shape corresponding to the color space type.
  14267. - *
  14268. - * @throws IllegalArgumentException if {@code shape} does not match the color space type
  14269. - * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
  14270. - */
  14271. - int getHeight(int[] shape) {
  14272. - assertRgbOrGrayScale("getHeight()");
  14273. - assertShape(shape);
  14274. - return getNormalizedShape(shape)[HEIGHT_DIM];
  14275. - }
  14276. -
  14277. - /**
  14278. - * Returns the channel value corresponding to the color space type.
  14279. - *
  14280. - * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
  14281. - */
  14282. - int getChannelValue() {
  14283. - throw new UnsupportedOperationException(
  14284. - "getChannelValue() is unsupported for the color space type " + this.name());
  14285. - }
  14286. - /**
  14287. - * Gets the normalized shape in the form of (1, h, w, c). Sometimes, a given shape may not have
  14288. - * batch or channel axis.
  14289. - *
  14290. - * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
  14291. - */
  14292. - int[] getNormalizedShape(int[] shape) {
  14293. - throw new UnsupportedOperationException(
  14294. - "getNormalizedShape() is unsupported for the color space type " + this.name());
  14295. - }
  14296. -
  14297. - /**
  14298. - * Returns the shape information corresponding to the color space type.
  14299. - *
  14300. - * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
  14301. - */
  14302. - String getShapeInfoMessage() {
  14303. - throw new UnsupportedOperationException(
  14304. - "getShapeInfoMessage() is unsupported for the color space type " + this.name());
  14305. - }
  14306. -
  14307. - /**
  14308. - * Converts the color space type to the corresponding bitmap config.
  14309. - *
  14310. - * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
  14311. - */
  14312. - Config toBitmapConfig() {
  14313. - throw new UnsupportedOperationException(
  14314. - "toBitmapConfig() is unsupported for the color space type " + this.name());
  14315. - }
  14316. -
  14317. - /**
  14318. - * Gets the number of elements given the height and width of an image. For example, the number of
  14319. - * elements of an RGB image of 30 x 20 is {@code 30 * 20 * 3 = 1800}; the number of elements of a
  14320. - * NV21 image of 30 x 20 is {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}.
  14321. - */
  14322. - abstract int getNumElements(int height, int width);
  14323. -
  14324. - private static int getYuv420NumElements(int height, int width) {
  14325. - // Height and width of U/V planes are half of the Y plane.
  14326. - return height * width + ((height + 1) / 2) * ((width + 1) / 2) * 2;
  14327. - }
  14328. -
  14329. - /** Inserts a value at the specified position and return the new array. */
  14330. - private static int[] insertValue(int[] array, int pos, int value) {
  14331. - int[] newArray = new int[array.length + 1];
  14332. - for (int i = 0; i < pos; i++) {
  14333. - newArray[i] = array[i];
  14334. - }
  14335. - newArray[pos] = value;
  14336. - for (int i = pos + 1; i < newArray.length; i++) {
  14337. - newArray[i] = array[i - 1];
  14338. +
  14339. + protected boolean isValidNormalizedShape(int[] shape) {
  14340. + return shape[BATCH_DIM] == BATCH_VALUE && shape[HEIGHT_DIM] > 0 && shape[WIDTH_DIM] > 0
  14341. + && shape[CHANNEL_DIM] == getChannelValue();
  14342. }
  14343. - return newArray;
  14344. - }
  14345. -
  14346. - protected boolean isValidNormalizedShape(int[] shape) {
  14347. - return shape[BATCH_DIM] == BATCH_VALUE
  14348. - && shape[HEIGHT_DIM] > 0
  14349. - && shape[WIDTH_DIM] > 0
  14350. - && shape[CHANNEL_DIM] == getChannelValue();
  14351. - }
  14352. -
  14353. - /** Some existing methods are only valid for RGB and GRAYSCALE images. */
  14354. - private void assertRgbOrGrayScale(String unsupportedMethodName) {
  14355. - if (this != ColorSpaceType.RGB && this != ColorSpaceType.GRAYSCALE) {
  14356. - throw new UnsupportedOperationException(
  14357. - unsupportedMethodName
  14358. - + " only supports RGB and GRAYSCALE formats, but not "
  14359. - + this.name());
  14360. +
  14361. + /** Some existing methods are only valid for RGB and GRAYSCALE images. */
  14362. + private void assertRgbOrGrayScale(String unsupportedMethodName) {
  14363. + if (this != ColorSpaceType.RGB && this != ColorSpaceType.GRAYSCALE) {
  14364. + throw new UnsupportedOperationException(unsupportedMethodName
  14365. + + " only supports RGB and GRAYSCALE formats, but not " + this.name());
  14366. + }
  14367. }
  14368. - }
  14369. }
  14370. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java
  14371. index 379d14798d62d..5c097da5ecb6d 100644
  14372. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java
  14373. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java
  14374. @@ -17,6 +17,7 @@ package org.tensorflow.lite.support.image;
  14375. import android.graphics.Bitmap;
  14376. import android.media.Image;
  14377. +
  14378. import org.tensorflow.lite.DataType;
  14379. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  14380. @@ -32,28 +33,27 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  14381. * </ul>
  14382. */
  14383. interface ImageContainer {
  14384. + /** Performs deep copy of the {@link ImageContainer}. */
  14385. + ImageContainer clone();
  14386. - /** Performs deep copy of the {@link ImageContainer}. */
  14387. - ImageContainer clone();
  14388. -
  14389. - /** Returns the width of the image. */
  14390. - int getWidth();
  14391. + /** Returns the width of the image. */
  14392. + int getWidth();
  14393. - /** Returns the height of the image. */
  14394. - int getHeight();
  14395. + /** Returns the height of the image. */
  14396. + int getHeight();
  14397. - /** Gets the {@link Bitmap} representation of the underlying image format. */
  14398. - Bitmap getBitmap();
  14399. + /** Gets the {@link Bitmap} representation of the underlying image format. */
  14400. + Bitmap getBitmap();
  14401. - /**
  14402. - * Gets the {@link TensorBuffer} representation with the specific {@code dataType} of the
  14403. - * underlying image format.
  14404. - */
  14405. - TensorBuffer getTensorBuffer(DataType dataType);
  14406. + /**
  14407. + * Gets the {@link TensorBuffer} representation with the specific {@code dataType} of the
  14408. + * underlying image format.
  14409. + */
  14410. + TensorBuffer getTensorBuffer(DataType dataType);
  14411. - /** Gets the {@link Image} representation of the underlying image format. */
  14412. - Image getMediaImage();
  14413. + /** Gets the {@link Image} representation of the underlying image format. */
  14414. + Image getMediaImage();
  14415. - /** Returns the color space type of the image. */
  14416. - ColorSpaceType getColorSpaceType();
  14417. + /** Returns the color space type of the image. */
  14418. + ColorSpaceType getColorSpaceType();
  14419. }
  14420. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java
  14421. index 8ed169c49348e..7ed5306fd9f96 100644
  14422. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java
  14423. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java
  14424. @@ -17,128 +17,127 @@ package org.tensorflow.lite.support.image;
  14425. import android.graphics.Bitmap;
  14426. import android.graphics.Color;
  14427. -import java.nio.ByteBuffer;
  14428. -import java.nio.ByteOrder;
  14429. +
  14430. import org.tensorflow.lite.DataType;
  14431. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  14432. +import java.nio.ByteBuffer;
  14433. +import java.nio.ByteOrder;
  14434. +
  14435. /**
  14436. * Implements some stateless image conversion methods.
  14437. *
  14438. * <p>This class is an internal helper for {@link org.tensorflow.lite.support.image}.
  14439. */
  14440. class ImageConversions {
  14441. + /**
  14442. + * Converts a {@link TensorBuffer} that represents a RGB image to an ARGB_8888 Bitmap.
  14443. + *
  14444. + * <p>Data in buffer will be converted into integer to match the Bitmap API.
  14445. + *
  14446. + * @param buffer a RGB image. Its shape should be either (h, w, 3) or (1, h, w, 3)
  14447. + * @throws IllegalArgumentException if the shape of buffer is neither (h, w, 3) nor (1, h, w, 3)
  14448. + */
  14449. + static Bitmap convertRgbTensorBufferToBitmap(TensorBuffer buffer) {
  14450. + int[] shape = buffer.getShape();
  14451. + ColorSpaceType rgb = ColorSpaceType.RGB;
  14452. + rgb.assertShape(shape);
  14453. - /**
  14454. - * Converts a {@link TensorBuffer} that represents a RGB image to an ARGB_8888 Bitmap.
  14455. - *
  14456. - * <p>Data in buffer will be converted into integer to match the Bitmap API.
  14457. - *
  14458. - * @param buffer a RGB image. Its shape should be either (h, w, 3) or (1, h, w, 3)
  14459. - * @throws IllegalArgumentException if the shape of buffer is neither (h, w, 3) nor (1, h, w, 3)
  14460. - */
  14461. - static Bitmap convertRgbTensorBufferToBitmap(TensorBuffer buffer) {
  14462. - int[] shape = buffer.getShape();
  14463. - ColorSpaceType rgb = ColorSpaceType.RGB;
  14464. - rgb.assertShape(shape);
  14465. -
  14466. - int h = rgb.getHeight(shape);
  14467. - int w = rgb.getWidth(shape);
  14468. - Bitmap bitmap = Bitmap.createBitmap(w, h, rgb.toBitmapConfig());
  14469. -
  14470. - // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
  14471. - int[] intValues = new int[w * h];
  14472. - int[] rgbValues = buffer.getIntArray();
  14473. - for (int i = 0, j = 0; i < intValues.length; i++) {
  14474. - int r = rgbValues[j++];
  14475. - int g = rgbValues[j++];
  14476. - int b = rgbValues[j++];
  14477. - intValues[i] = Color.rgb(r, g, b);
  14478. - }
  14479. - bitmap.setPixels(intValues, 0, w, 0, 0, w, h);
  14480. -
  14481. - return bitmap;
  14482. - }
  14483. -
  14484. - /**
  14485. - * Converts a {@link TensorBuffer} that represents a grayscale image to an ALPHA_8 Bitmap.
  14486. - *
  14487. - * <p>Data in buffer will be converted into integer to match the Bitmap API.
  14488. - *
  14489. - * @param buffer a grayscale image. Its shape should be either (h, w) or (1, h, w)
  14490. - * @throws IllegalArgumentException if the shape of buffer is neither (h, w) nor (1, h, w, 1)
  14491. - */
  14492. - static Bitmap convertGrayscaleTensorBufferToBitmap(TensorBuffer buffer) {
  14493. - // Convert buffer into Uint8 as needed.
  14494. - TensorBuffer uint8Buffer =
  14495. - buffer.getDataType() == DataType.UINT8
  14496. - ? buffer
  14497. - : TensorBuffer.createFrom(buffer, DataType.UINT8);
  14498. -
  14499. - int[] shape = uint8Buffer.getShape();
  14500. - ColorSpaceType grayscale = ColorSpaceType.GRAYSCALE;
  14501. - grayscale.assertShape(shape);
  14502. -
  14503. - // Even though `Bitmap.createBitmap(int[] colors, int width, int height, Bitmap.Config config)`
  14504. - // seems to work for internal Android testing framework, but it actually doesn't work for the
  14505. - // real Android environment.
  14506. - //
  14507. - // The only reliable way to create an ALPHA_8 Bitmap is to use `copyPixelsFromBuffer()` to load
  14508. - // the pixels from a ByteBuffer, and then use `copyPixelsToBuffer` to read out.
  14509. - // Note: for ALPHA_8 Bitmap, methods such as, `setPixels()` and `getPixels()` do not work.
  14510. - Bitmap bitmap =
  14511. - Bitmap.createBitmap(
  14512. - grayscale.getWidth(shape), grayscale.getHeight(shape), grayscale.toBitmapConfig());
  14513. - uint8Buffer.getBuffer().rewind();
  14514. - bitmap.copyPixelsFromBuffer(uint8Buffer.getBuffer());
  14515. - return bitmap;
  14516. - }
  14517. -
  14518. - /**
  14519. - * Converts an Image in a Bitmap to a TensorBuffer (3D Tensor: Width-Height-Channel) whose memory
  14520. - * is already allocated, or could be dynamically allocated.
  14521. - *
  14522. - * @param bitmap The Bitmap object representing the image. Currently we only support ARGB_8888
  14523. - * config.
  14524. - * @param buffer The destination of the conversion. Needs to be created in advance. If it's
  14525. - * fixed-size, its flat size should be w*h*3.
  14526. - * @throws IllegalArgumentException if the buffer is fixed-size, but the size doesn't match.
  14527. - */
  14528. - static void convertBitmapToTensorBuffer(Bitmap bitmap, TensorBuffer buffer) {
  14529. - int w = bitmap.getWidth();
  14530. - int h = bitmap.getHeight();
  14531. - int[] intValues = new int[w * h];
  14532. - bitmap.getPixels(intValues, 0, w, 0, 0, w, h);
  14533. - // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
  14534. - int[] shape = new int[] {h, w, 3};
  14535. - switch (buffer.getDataType()) {
  14536. - case UINT8:
  14537. - byte[] byteArr = new byte[w * h * 3];
  14538. + int h = rgb.getHeight(shape);
  14539. + int w = rgb.getWidth(shape);
  14540. + Bitmap bitmap = Bitmap.createBitmap(w, h, rgb.toBitmapConfig());
  14541. +
  14542. + // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
  14543. + int[] intValues = new int[w * h];
  14544. + int[] rgbValues = buffer.getIntArray();
  14545. for (int i = 0, j = 0; i < intValues.length; i++) {
  14546. - byteArr[j++] = (byte) ((intValues[i] >> 16) & 0xff);
  14547. - byteArr[j++] = (byte) ((intValues[i] >> 8) & 0xff);
  14548. - byteArr[j++] = (byte) (intValues[i] & 0xff);
  14549. + int r = rgbValues[j++];
  14550. + int g = rgbValues[j++];
  14551. + int b = rgbValues[j++];
  14552. + intValues[i] = Color.rgb(r, g, b);
  14553. }
  14554. - ByteBuffer byteBuffer = ByteBuffer.wrap(byteArr);
  14555. - byteBuffer.order(ByteOrder.nativeOrder());
  14556. - buffer.loadBuffer(byteBuffer, shape);
  14557. - break;
  14558. - case FLOAT32:
  14559. - float[] floatArr = new float[w * h * 3];
  14560. - for (int i = 0, j = 0; i < intValues.length; i++) {
  14561. - floatArr[j++] = (float) ((intValues[i] >> 16) & 0xff);
  14562. - floatArr[j++] = (float) ((intValues[i] >> 8) & 0xff);
  14563. - floatArr[j++] = (float) (intValues[i] & 0xff);
  14564. + bitmap.setPixels(intValues, 0, w, 0, 0, w, h);
  14565. +
  14566. + return bitmap;
  14567. + }
  14568. +
  14569. + /**
  14570. + * Converts a {@link TensorBuffer} that represents a grayscale image to an ALPHA_8 Bitmap.
  14571. + *
  14572. + * <p>Data in buffer will be converted into integer to match the Bitmap API.
  14573. + *
  14574. + * @param buffer a grayscale image. Its shape should be either (h, w) or (1, h, w)
  14575. + * @throws IllegalArgumentException if the shape of buffer is neither (h, w) nor (1, h, w, 1)
  14576. + */
  14577. + static Bitmap convertGrayscaleTensorBufferToBitmap(TensorBuffer buffer) {
  14578. + // Convert buffer into Uint8 as needed.
  14579. + TensorBuffer uint8Buffer = buffer.getDataType() == DataType.UINT8
  14580. + ? buffer
  14581. + : TensorBuffer.createFrom(buffer, DataType.UINT8);
  14582. +
  14583. + int[] shape = uint8Buffer.getShape();
  14584. + ColorSpaceType grayscale = ColorSpaceType.GRAYSCALE;
  14585. + grayscale.assertShape(shape);
  14586. +
  14587. + // Even though `Bitmap.createBitmap(int[] colors, int width, int height, Bitmap.Config
  14588. + // config)` seems to work for internal Android testing framework, but it actually doesn't
  14589. + // work for the real Android environment.
  14590. + //
  14591. + // The only reliable way to create an ALPHA_8 Bitmap is to use `copyPixelsFromBuffer()` to
  14592. + // load the pixels from a ByteBuffer, and then use `copyPixelsToBuffer` to read out. Note:
  14593. + // for ALPHA_8 Bitmap, methods such as, `setPixels()` and `getPixels()` do not work.
  14594. + Bitmap bitmap = Bitmap.createBitmap(
  14595. + grayscale.getWidth(shape), grayscale.getHeight(shape), grayscale.toBitmapConfig());
  14596. + uint8Buffer.getBuffer().rewind();
  14597. + bitmap.copyPixelsFromBuffer(uint8Buffer.getBuffer());
  14598. + return bitmap;
  14599. + }
  14600. +
  14601. + /**
  14602. + * Converts an Image in a Bitmap to a TensorBuffer (3D Tensor: Width-Height-Channel) whose
  14603. + * memory is already allocated, or could be dynamically allocated.
  14604. + *
  14605. + * @param bitmap The Bitmap object representing the image. Currently we only support ARGB_8888
  14606. + * config.
  14607. + * @param buffer The destination of the conversion. Needs to be created in advance. If it's
  14608. + * fixed-size, its flat size should be w*h*3.
  14609. + * @throws IllegalArgumentException if the buffer is fixed-size, but the size doesn't match.
  14610. + */
  14611. + static void convertBitmapToTensorBuffer(Bitmap bitmap, TensorBuffer buffer) {
  14612. + int w = bitmap.getWidth();
  14613. + int h = bitmap.getHeight();
  14614. + int[] intValues = new int[w * h];
  14615. + bitmap.getPixels(intValues, 0, w, 0, 0, w, h);
  14616. + // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
  14617. + int[] shape = new int[] {h, w, 3};
  14618. + switch (buffer.getDataType()) {
  14619. + case UINT8:
  14620. + byte[] byteArr = new byte[w * h * 3];
  14621. + for (int i = 0, j = 0; i < intValues.length; i++) {
  14622. + byteArr[j++] = (byte) ((intValues[i] >> 16) & 0xff);
  14623. + byteArr[j++] = (byte) ((intValues[i] >> 8) & 0xff);
  14624. + byteArr[j++] = (byte) (intValues[i] & 0xff);
  14625. + }
  14626. + ByteBuffer byteBuffer = ByteBuffer.wrap(byteArr);
  14627. + byteBuffer.order(ByteOrder.nativeOrder());
  14628. + buffer.loadBuffer(byteBuffer, shape);
  14629. + break;
  14630. + case FLOAT32:
  14631. + float[] floatArr = new float[w * h * 3];
  14632. + for (int i = 0, j = 0; i < intValues.length; i++) {
  14633. + floatArr[j++] = (float) ((intValues[i] >> 16) & 0xff);
  14634. + floatArr[j++] = (float) ((intValues[i] >> 8) & 0xff);
  14635. + floatArr[j++] = (float) (intValues[i] & 0xff);
  14636. + }
  14637. + buffer.loadArray(floatArr, shape);
  14638. + break;
  14639. + default:
  14640. + // Should never happen.
  14641. + throw new IllegalStateException(
  14642. + "The type of TensorBuffer, " + buffer.getBuffer() + ", is unsupported.");
  14643. }
  14644. - buffer.loadArray(floatArr, shape);
  14645. - break;
  14646. - default:
  14647. - // Should never happen.
  14648. - throw new IllegalStateException(
  14649. - "The type of TensorBuffer, " + buffer.getBuffer() + ", is unsupported.");
  14650. }
  14651. - }
  14652. - // Hide the constructor as the class is static.
  14653. - private ImageConversions() {}
  14654. + // Hide the constructor as the class is static.
  14655. + private ImageConversions() {}
  14656. }
  14657. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java
  14658. index 1e546634e90e7..e852569490f0b 100644
  14659. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java
  14660. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java
  14661. @@ -16,28 +16,29 @@ limitations under the License.
  14662. package org.tensorflow.lite.support.image;
  14663. import android.graphics.PointF;
  14664. +
  14665. import org.tensorflow.lite.support.common.Operator;
  14666. /** Operates a TensorImage object. Used in ImageProcessor. */
  14667. public interface ImageOperator extends Operator<TensorImage> {
  14668. - /** @see org.tensorflow.lite.support.common.Operator#apply(java.lang.Object) */
  14669. - @Override
  14670. - TensorImage apply(TensorImage image);
  14671. -
  14672. - /** Computes the width of the expected output image when input image size is given. */
  14673. - int getOutputImageWidth(int inputImageHeight, int inputImageWidth);
  14674. -
  14675. - /** Computes the height of the expected output image when input image size is given. */
  14676. - int getOutputImageHeight(int inputImageHeight, int inputImageWidth);
  14677. -
  14678. - /**
  14679. - * Transforms a point from coordinates system of the result image back to the one of the input
  14680. - * image.
  14681. - *
  14682. - * @param point the point from the result coordinates system.
  14683. - * @param inputImageHeight the height of input image.
  14684. - * @param inputImageWidth the width of input image.
  14685. - * @return the point with the coordinates from the coordinates system of the input image.
  14686. - */
  14687. - PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth);
  14688. + /** @see org.tensorflow.lite.support.common.Operator#apply(java.lang.Object) */
  14689. + @Override
  14690. + TensorImage apply(TensorImage image);
  14691. +
  14692. + /** Computes the width of the expected output image when input image size is given. */
  14693. + int getOutputImageWidth(int inputImageHeight, int inputImageWidth);
  14694. +
  14695. + /** Computes the height of the expected output image when input image size is given. */
  14696. + int getOutputImageHeight(int inputImageHeight, int inputImageWidth);
  14697. +
  14698. + /**
  14699. + * Transforms a point from coordinates system of the result image back to the one of the input
  14700. + * image.
  14701. + *
  14702. + * @param point the point from the result coordinates system.
  14703. + * @param inputImageHeight the height of input image.
  14704. + * @param inputImageWidth the width of input image.
  14705. + * @return the point with the coordinates from the coordinates system of the input image.
  14706. + */
  14707. + PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth);
  14708. }
  14709. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java
  14710. index c44aa9efad708..c7d51355920ee 100644
  14711. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java
  14712. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java
  14713. @@ -20,9 +20,7 @@ import static java.lang.Math.min;
  14714. import android.graphics.PointF;
  14715. import android.graphics.RectF;
  14716. -import java.util.ArrayList;
  14717. -import java.util.List;
  14718. -import java.util.ListIterator;
  14719. +
  14720. import org.tensorflow.lite.support.common.Operator;
  14721. import org.tensorflow.lite.support.common.SequentialProcessor;
  14722. import org.tensorflow.lite.support.common.TensorOperator;
  14723. @@ -30,6 +28,10 @@ import org.tensorflow.lite.support.common.internal.SupportPreconditions;
  14724. import org.tensorflow.lite.support.image.ops.Rot90Op;
  14725. import org.tensorflow.lite.support.image.ops.TensorOperatorWrapper;
  14726. +import java.util.ArrayList;
  14727. +import java.util.List;
  14728. +import java.util.ListIterator;
  14729. +
  14730. /**
  14731. * ImageProcessor is a helper class for preprocessing and postprocessing {@link TensorImage}. It
  14732. * could transform a {@link TensorImage} to another by executing a chain of {@link ImageOperator}.
  14733. @@ -55,156 +57,159 @@ import org.tensorflow.lite.support.image.ops.TensorOperatorWrapper;
  14734. * @see ImageProcessor#process(TensorImage) to apply the processor on a {@code TensorImage}
  14735. */
  14736. public class ImageProcessor extends SequentialProcessor<TensorImage> {
  14737. - private ImageProcessor(Builder builder) {
  14738. - super(builder);
  14739. - }
  14740. -
  14741. - /**
  14742. - * Transforms a point from coordinates system of the result image back to the one of the input
  14743. - * image.
  14744. - *
  14745. - * @param point the point from the result coordinates system.
  14746. - * @param inputImageHeight the height of input image.
  14747. - * @param inputImageWidth the width of input image.
  14748. - * @return the point with the coordinates from the coordinates system of the input image.
  14749. - */
  14750. - public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
  14751. - List<Integer> widths = new ArrayList<>();
  14752. - List<Integer> heights = new ArrayList<>();
  14753. - int currentWidth = inputImageWidth;
  14754. - int currentHeight = inputImageHeight;
  14755. - for (Operator<TensorImage> op : operatorList) {
  14756. - widths.add(currentWidth);
  14757. - heights.add(currentHeight);
  14758. - ImageOperator imageOperator = (ImageOperator) op;
  14759. - int newHeight = imageOperator.getOutputImageHeight(currentHeight, currentWidth);
  14760. - int newWidth = imageOperator.getOutputImageWidth(currentHeight, currentWidth);
  14761. - currentHeight = newHeight;
  14762. - currentWidth = newWidth;
  14763. + private ImageProcessor(Builder builder) {
  14764. + super(builder);
  14765. }
  14766. - ListIterator<Operator<TensorImage>> opIterator = operatorList.listIterator(operatorList.size());
  14767. - ListIterator<Integer> widthIterator = widths.listIterator(widths.size());
  14768. - ListIterator<Integer> heightIterator = heights.listIterator(heights.size());
  14769. - while (opIterator.hasPrevious()) {
  14770. - ImageOperator imageOperator = (ImageOperator) opIterator.previous();
  14771. - int height = heightIterator.previous();
  14772. - int width = widthIterator.previous();
  14773. - point = imageOperator.inverseTransform(point, height, width);
  14774. +
  14775. + /**
  14776. + * Transforms a point from coordinates system of the result image back to the one of the input
  14777. + * image.
  14778. + *
  14779. + * @param point the point from the result coordinates system.
  14780. + * @param inputImageHeight the height of input image.
  14781. + * @param inputImageWidth the width of input image.
  14782. + * @return the point with the coordinates from the coordinates system of the input image.
  14783. + */
  14784. + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
  14785. + List<Integer> widths = new ArrayList<>();
  14786. + List<Integer> heights = new ArrayList<>();
  14787. + int currentWidth = inputImageWidth;
  14788. + int currentHeight = inputImageHeight;
  14789. + for (Operator<TensorImage> op : operatorList) {
  14790. + widths.add(currentWidth);
  14791. + heights.add(currentHeight);
  14792. + ImageOperator imageOperator = (ImageOperator) op;
  14793. + int newHeight = imageOperator.getOutputImageHeight(currentHeight, currentWidth);
  14794. + int newWidth = imageOperator.getOutputImageWidth(currentHeight, currentWidth);
  14795. + currentHeight = newHeight;
  14796. + currentWidth = newWidth;
  14797. + }
  14798. + ListIterator<Operator<TensorImage>> opIterator =
  14799. + operatorList.listIterator(operatorList.size());
  14800. + ListIterator<Integer> widthIterator = widths.listIterator(widths.size());
  14801. + ListIterator<Integer> heightIterator = heights.listIterator(heights.size());
  14802. + while (opIterator.hasPrevious()) {
  14803. + ImageOperator imageOperator = (ImageOperator) opIterator.previous();
  14804. + int height = heightIterator.previous();
  14805. + int width = widthIterator.previous();
  14806. + point = imageOperator.inverseTransform(point, height, width);
  14807. + }
  14808. + return point;
  14809. + }
  14810. +
  14811. + /**
  14812. + * Transforms a rectangle from coordinates system of the result image back to the one of the
  14813. + * input image.
  14814. + *
  14815. + * @param rect the rectangle from the result coordinates system.
  14816. + * @param inputImageHeight the height of input image.
  14817. + * @param inputImageWidth the width of input image.
  14818. + * @return the rectangle with the coordinates from the coordinates system of the input image.
  14819. + */
  14820. + public RectF inverseTransform(RectF rect, int inputImageHeight, int inputImageWidth) {
  14821. + // when rotation is involved, corner order may change - top left changes to bottom right,
  14822. + // .etc
  14823. + PointF p1 = inverseTransform(
  14824. + new PointF(rect.left, rect.top), inputImageHeight, inputImageWidth);
  14825. + PointF p2 = inverseTransform(
  14826. + new PointF(rect.right, rect.bottom), inputImageHeight, inputImageWidth);
  14827. + return new RectF(min(p1.x, p2.x), min(p1.y, p2.y), max(p1.x, p2.x), max(p1.y, p2.y));
  14828. }
  14829. - return point;
  14830. - }
  14831. -
  14832. - /**
  14833. - * Transforms a rectangle from coordinates system of the result image back to the one of the input
  14834. - * image.
  14835. - *
  14836. - * @param rect the rectangle from the result coordinates system.
  14837. - * @param inputImageHeight the height of input image.
  14838. - * @param inputImageWidth the width of input image.
  14839. - * @return the rectangle with the coordinates from the coordinates system of the input image.
  14840. - */
  14841. - public RectF inverseTransform(RectF rect, int inputImageHeight, int inputImageWidth) {
  14842. - // when rotation is involved, corner order may change - top left changes to bottom right, .etc
  14843. - PointF p1 =
  14844. - inverseTransform(new PointF(rect.left, rect.top), inputImageHeight, inputImageWidth);
  14845. - PointF p2 =
  14846. - inverseTransform(new PointF(rect.right, rect.bottom), inputImageHeight, inputImageWidth);
  14847. - return new RectF(min(p1.x, p2.x), min(p1.y, p2.y), max(p1.x, p2.x), max(p1.y, p2.y));
  14848. - }
  14849. -
  14850. - /**
  14851. - * Processes a {@link TensorImage} object with prepared {@link TensorOperator}.
  14852. - *
  14853. - * @throws IllegalArgumentException if the image is not supported by any op.
  14854. - */
  14855. - @Override
  14856. - public TensorImage process(TensorImage image) {
  14857. - return super.process(image);
  14858. - }
  14859. -
  14860. - /**
  14861. - * The Builder to create an ImageProcessor, which could be executed later.
  14862. - *
  14863. - * @see #add(TensorOperator) to add a general TensorOperator
  14864. - * @see #add(ImageOperator) to add an ImageOperator
  14865. - * @see #build() complete the building process and get a built Processor
  14866. - */
  14867. - public static class Builder extends SequentialProcessor.Builder<TensorImage> {
  14868. - public Builder() {
  14869. - super();
  14870. +
  14871. + /**
  14872. + * Processes a {@link TensorImage} object with prepared {@link TensorOperator}.
  14873. + *
  14874. + * @throws IllegalArgumentException if the image is not supported by any op.
  14875. + */
  14876. + @Override
  14877. + public TensorImage process(TensorImage image) {
  14878. + return super.process(image);
  14879. }
  14880. /**
  14881. - * Adds an {@link ImageOperator} into the Operator chain.
  14882. + * The Builder to create an ImageProcessor, which could be executed later.
  14883. *
  14884. - * @param op the Operator instance to be executed then
  14885. + * @see #add(TensorOperator) to add a general TensorOperator
  14886. + * @see #add(ImageOperator) to add an ImageOperator
  14887. + * @see #build() complete the building process and get a built Processor
  14888. */
  14889. - public Builder add(ImageOperator op) {
  14890. - super.add(op);
  14891. - return this;
  14892. + public static class Builder extends SequentialProcessor.Builder<TensorImage> {
  14893. + public Builder() {
  14894. + super();
  14895. + }
  14896. +
  14897. + /**
  14898. + * Adds an {@link ImageOperator} into the Operator chain.
  14899. + *
  14900. + * @param op the Operator instance to be executed then
  14901. + */
  14902. + public Builder add(ImageOperator op) {
  14903. + super.add(op);
  14904. + return this;
  14905. + }
  14906. +
  14907. + /**
  14908. + * Adds a {@link TensorOperator} into the Operator chain. In execution, the processor calls
  14909. + * {@link TensorImage#getTensorBuffer()} to transform the {@link TensorImage} by
  14910. + * transforming the underlying {@link
  14911. + * org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
  14912. + *
  14913. + * @param op the Operator instance to be executed then
  14914. + */
  14915. + public Builder add(TensorOperator op) {
  14916. + return add(new TensorOperatorWrapper(op));
  14917. + }
  14918. +
  14919. + /** Completes the building process and gets the {@link ImageProcessor} instance. */
  14920. + @Override
  14921. + public ImageProcessor build() {
  14922. + return new ImageProcessor(this);
  14923. + }
  14924. }
  14925. /**
  14926. - * Adds a {@link TensorOperator} into the Operator chain. In execution, the processor calls
  14927. - * {@link TensorImage#getTensorBuffer()} to transform the {@link TensorImage} by transforming
  14928. - * the underlying {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
  14929. + * Updates the number of rotations for the first {@link Rot90Op} in this {@link ImageProcessor}.
  14930. + *
  14931. + * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and
  14932. + * then processing images (using {@link #process}) must be protected from concurrent access with
  14933. + * additional synchronization.
  14934. *
  14935. - * @param op the Operator instance to be executed then
  14936. + * @param k the number of rotations
  14937. + * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
  14938. + * ImageProcessor}
  14939. */
  14940. - public Builder add(TensorOperator op) {
  14941. - return add(new TensorOperatorWrapper(op));
  14942. + public void updateNumberOfRotations(int k) {
  14943. + updateNumberOfRotations(k, /*occurrence=*/0);
  14944. }
  14945. - /** Completes the building process and gets the {@link ImageProcessor} instance. */
  14946. - @Override
  14947. - public ImageProcessor build() {
  14948. - return new ImageProcessor(this);
  14949. + /**
  14950. + * Updates the number of rotations for the {@link Rot90Op} specified by {@code occurrence} in
  14951. + * this
  14952. + * {@link ImageProcessor}.
  14953. + *
  14954. + * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and
  14955. + * then processing images (using {@link #process}) must be protected from concurrent access with
  14956. + * additional synchronization.
  14957. + *
  14958. + * @param k the number of rotations
  14959. + * @param occurrence the index of perticular {@link Rot90Op} in this {@link ImageProcessor}. For
  14960. + * example, if the second {@link Rot90Op} needs to be updated, {@code occurrence} should be
  14961. + * set to 1.
  14962. + * @throws IndexOutOfBoundsException if {@code occurrence} is negative or is not less than the
  14963. + * number of {@link Rot90Op} in this {@link ImageProcessor}
  14964. + * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
  14965. + * ImageProcessor}
  14966. + */
  14967. + public synchronized void updateNumberOfRotations(int k, int occurrence) {
  14968. + SupportPreconditions.checkState(operatorIndex.containsKey(Rot90Op.class.getName()),
  14969. + "The Rot90Op has not been added to the ImageProcessor.");
  14970. +
  14971. + List<Integer> indexes = operatorIndex.get(Rot90Op.class.getName());
  14972. + SupportPreconditions.checkElementIndex(occurrence, indexes.size(), "occurrence");
  14973. +
  14974. + // The index of the Rot90Op to be replaced in operatorList.
  14975. + int index = indexes.get(occurrence);
  14976. + Rot90Op newRot = new Rot90Op(k);
  14977. + operatorList.set(index, newRot);
  14978. }
  14979. - }
  14980. -
  14981. - /**
  14982. - * Updates the number of rotations for the first {@link Rot90Op} in this {@link ImageProcessor}.
  14983. - *
  14984. - * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and
  14985. - * then processing images (using {@link #process}) must be protected from concurrent access with
  14986. - * additional synchronization.
  14987. - *
  14988. - * @param k the number of rotations
  14989. - * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
  14990. - * ImageProcessor}
  14991. - */
  14992. - public void updateNumberOfRotations(int k) {
  14993. - updateNumberOfRotations(k, /*occurrence=*/ 0);
  14994. - }
  14995. -
  14996. - /**
  14997. - * Updates the number of rotations for the {@link Rot90Op} specified by {@code occurrence} in this
  14998. - * {@link ImageProcessor}.
  14999. - *
  15000. - * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and
  15001. - * then processing images (using {@link #process}) must be protected from concurrent access with
  15002. - * additional synchronization.
  15003. - *
  15004. - * @param k the number of rotations
  15005. - * @param occurrence the index of perticular {@link Rot90Op} in this {@link ImageProcessor}. For
  15006. - * example, if the second {@link Rot90Op} needs to be updated, {@code occurrence} should be
  15007. - * set to 1.
  15008. - * @throws IndexOutOfBoundsException if {@code occurrence} is negative or is not less than the
  15009. - * number of {@link Rot90Op} in this {@link ImageProcessor}
  15010. - * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
  15011. - * ImageProcessor}
  15012. - */
  15013. - public synchronized void updateNumberOfRotations(int k, int occurrence) {
  15014. - SupportPreconditions.checkState(
  15015. - operatorIndex.containsKey(Rot90Op.class.getName()),
  15016. - "The Rot90Op has not been added to the ImageProcessor.");
  15017. -
  15018. - List<Integer> indexes = operatorIndex.get(Rot90Op.class.getName());
  15019. - SupportPreconditions.checkElementIndex(occurrence, indexes.size(), "occurrence");
  15020. -
  15021. - // The index of the Rot90Op to be replaced in operatorList.
  15022. - int index = indexes.get(occurrence);
  15023. - Rot90Op newRot = new Rot90Op(k);
  15024. - operatorList.set(index, newRot);
  15025. - }
  15026. }
  15027. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java
  15028. index 96daf85a02f5a..f61f59fa13ce7 100644
  15029. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java
  15030. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java
  15031. @@ -26,52 +26,51 @@ import com.google.auto.value.AutoValue;
  15032. */
  15033. @AutoValue
  15034. public abstract class ImageProperties {
  15035. + private static final int DEFAULT_HEIGHT = -1;
  15036. + private static final int DEFAULT_WIDTH = -1;
  15037. - private static final int DEFAULT_HEIGHT = -1;
  15038. - private static final int DEFAULT_WIDTH = -1;
  15039. -
  15040. - public abstract int getHeight();
  15041. -
  15042. - public abstract int getWidth();
  15043. -
  15044. - public abstract ColorSpaceType getColorSpaceType();
  15045. -
  15046. - public static Builder builder() {
  15047. - return new AutoValue_ImageProperties.Builder()
  15048. - .setHeight(DEFAULT_HEIGHT)
  15049. - .setWidth(DEFAULT_WIDTH);
  15050. - }
  15051. -
  15052. - /**
  15053. - * Builder for {@link ImageProperties}. Different image objects may require different properties.
  15054. - * See the detais below:
  15055. - *
  15056. - * <ul>
  15057. - * {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}:
  15058. - * <li>Mandatory proterties: height / width / colorSpaceType. The shape of the TensorBuffer
  15059. - * object will not be used to determine image height and width.
  15060. - * </ul>
  15061. - */
  15062. - @AutoValue.Builder
  15063. - public abstract static class Builder {
  15064. - public abstract Builder setHeight(int height);
  15065. -
  15066. - public abstract Builder setWidth(int width);
  15067. -
  15068. - public abstract Builder setColorSpaceType(ColorSpaceType colorSpaceType);
  15069. -
  15070. - abstract ImageProperties autoBuild();
  15071. -
  15072. - public ImageProperties build() {
  15073. - ImageProperties properties = autoBuild();
  15074. - // If width or hight are not configured by the Builder, they will be -1.
  15075. - // Enforcing all properties to be populated (AutoValue will error out if objects, like
  15076. - // colorSpaceType, are not set up), since they are required for TensorBuffer images.
  15077. - // If in the future we have some image object types that only require a portion of these
  15078. - // properties, we can delay the check when TensorImage#load() is executed.
  15079. - checkState(properties.getHeight() >= 0, "Negative image height is not allowed.");
  15080. - checkState(properties.getWidth() >= 0, "Negative image width is not allowed.");
  15081. - return properties;
  15082. + public abstract int getHeight();
  15083. +
  15084. + public abstract int getWidth();
  15085. +
  15086. + public abstract ColorSpaceType getColorSpaceType();
  15087. +
  15088. + public static Builder builder() {
  15089. + return new AutoValue_ImageProperties.Builder()
  15090. + .setHeight(DEFAULT_HEIGHT)
  15091. + .setWidth(DEFAULT_WIDTH);
  15092. + }
  15093. +
  15094. + /**
  15095. + * Builder for {@link ImageProperties}. Different image objects may require different
  15096. + * properties. See the detais below:
  15097. + *
  15098. + * <ul>
  15099. + * {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}:
  15100. + * <li>Mandatory proterties: height / width / colorSpaceType. The shape of the TensorBuffer
  15101. + * object will not be used to determine image height and width.
  15102. + * </ul>
  15103. + */
  15104. + @AutoValue.Builder
  15105. + public abstract static class Builder {
  15106. + public abstract Builder setHeight(int height);
  15107. +
  15108. + public abstract Builder setWidth(int width);
  15109. +
  15110. + public abstract Builder setColorSpaceType(ColorSpaceType colorSpaceType);
  15111. +
  15112. + abstract ImageProperties autoBuild();
  15113. +
  15114. + public ImageProperties build() {
  15115. + ImageProperties properties = autoBuild();
  15116. + // If width or hight are not configured by the Builder, they will be -1.
  15117. + // Enforcing all properties to be populated (AutoValue will error out if objects, like
  15118. + // colorSpaceType, are not set up), since they are required for TensorBuffer images.
  15119. + // If in the future we have some image object types that only require a portion of these
  15120. + // properties, we can delay the check when TensorImage#load() is executed.
  15121. + checkState(properties.getHeight() >= 0, "Negative image height is not allowed.");
  15122. + checkState(properties.getWidth() >= 0, "Negative image width is not allowed.");
  15123. + return properties;
  15124. + }
  15125. }
  15126. - }
  15127. }
  15128. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java
  15129. index 50d787b5afab1..519aacaf7f20b 100644
  15130. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java
  15131. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java
  15132. @@ -21,65 +21,65 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
  15133. import android.graphics.Bitmap;
  15134. import android.graphics.ImageFormat;
  15135. import android.media.Image;
  15136. +
  15137. import org.tensorflow.lite.DataType;
  15138. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  15139. /** Holds an {@link Image} and converts it to other image formats as needed. */
  15140. final class MediaImageContainer implements ImageContainer {
  15141. -
  15142. - private final Image image;
  15143. -
  15144. - /**
  15145. - * Creates a {@link MediaImageContainer} object with a YUV_420_888 {@link Image}.
  15146. - *
  15147. - * @throws IllegalArgumentException if the {@link ImageFormat} of {@code image} is not ARGB_8888
  15148. - */
  15149. - static MediaImageContainer create(Image image) {
  15150. - return new MediaImageContainer(image);
  15151. - }
  15152. -
  15153. - private MediaImageContainer(Image image) {
  15154. - checkNotNull(image, "Cannot load null Image.");
  15155. - checkArgument(
  15156. - image.getFormat() == ImageFormat.YUV_420_888, "Only supports loading YUV_420_888 Image.");
  15157. - this.image = image;
  15158. - }
  15159. -
  15160. - @Override
  15161. - public MediaImageContainer clone() {
  15162. - throw new UnsupportedOperationException(
  15163. - "android.media.Image is an abstract class and cannot be cloned.");
  15164. - }
  15165. -
  15166. - @Override
  15167. - public Bitmap getBitmap() {
  15168. - throw new UnsupportedOperationException(
  15169. - "Converting an android.media.Image to Bitmap is not supported.");
  15170. - }
  15171. -
  15172. - @Override
  15173. - public TensorBuffer getTensorBuffer(DataType dataType) {
  15174. - throw new UnsupportedOperationException(
  15175. - "Converting an android.media.Image to TesorBuffer is not supported.");
  15176. - }
  15177. -
  15178. - @Override
  15179. - public Image getMediaImage() {
  15180. - return image;
  15181. - }
  15182. -
  15183. - @Override
  15184. - public int getWidth() {
  15185. - return image.getWidth();
  15186. - }
  15187. -
  15188. - @Override
  15189. - public int getHeight() {
  15190. - return image.getHeight();
  15191. - }
  15192. -
  15193. - @Override
  15194. - public ColorSpaceType getColorSpaceType() {
  15195. - return ColorSpaceType.fromImageFormat(image.getFormat());
  15196. - }
  15197. + private final Image image;
  15198. +
  15199. + /**
  15200. + * Creates a {@link MediaImageContainer} object with a YUV_420_888 {@link Image}.
  15201. + *
  15202. + * @throws IllegalArgumentException if the {@link ImageFormat} of {@code image} is not ARGB_8888
  15203. + */
  15204. + static MediaImageContainer create(Image image) {
  15205. + return new MediaImageContainer(image);
  15206. + }
  15207. +
  15208. + private MediaImageContainer(Image image) {
  15209. + checkNotNull(image, "Cannot load null Image.");
  15210. + checkArgument(image.getFormat() == ImageFormat.YUV_420_888,
  15211. + "Only supports loading YUV_420_888 Image.");
  15212. + this.image = image;
  15213. + }
  15214. +
  15215. + @Override
  15216. + public MediaImageContainer clone() {
  15217. + throw new UnsupportedOperationException(
  15218. + "android.media.Image is an abstract class and cannot be cloned.");
  15219. + }
  15220. +
  15221. + @Override
  15222. + public Bitmap getBitmap() {
  15223. + throw new UnsupportedOperationException(
  15224. + "Converting an android.media.Image to Bitmap is not supported.");
  15225. + }
  15226. +
  15227. + @Override
  15228. + public TensorBuffer getTensorBuffer(DataType dataType) {
  15229. + throw new UnsupportedOperationException(
  15230. + "Converting an android.media.Image to TesorBuffer is not supported.");
  15231. + }
  15232. +
  15233. + @Override
  15234. + public Image getMediaImage() {
  15235. + return image;
  15236. + }
  15237. +
  15238. + @Override
  15239. + public int getWidth() {
  15240. + return image.getWidth();
  15241. + }
  15242. +
  15243. + @Override
  15244. + public int getHeight() {
  15245. + return image.getHeight();
  15246. + }
  15247. +
  15248. + @Override
  15249. + public ColorSpaceType getColorSpaceType() {
  15250. + return ColorSpaceType.fromImageFormat(image.getFormat());
  15251. + }
  15252. }
  15253. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java
  15254. index ed066e5308fb9..03017bf733f02 100644
  15255. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java
  15256. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java
  15257. @@ -21,91 +21,99 @@ import com.google.android.odml.image.MediaImageExtractor;
  15258. import com.google.android.odml.image.MlImage;
  15259. import com.google.android.odml.image.MlImage.ImageFormat;
  15260. import com.google.auto.value.AutoValue;
  15261. +
  15262. import java.nio.ByteBuffer;
  15263. /** Converts {@code MlImage} to {@link TensorImage} and vice versa. */
  15264. public class MlImageAdapter {
  15265. + /** Proxies an {@link ImageFormat} and its equivalent {@link ColorSpaceType}. */
  15266. + @AutoValue
  15267. + abstract static class ImageFormatProxy {
  15268. + abstract ColorSpaceType getColorSpaceType();
  15269. - /** Proxies an {@link ImageFormat} and its equivalent {@link ColorSpaceType}. */
  15270. - @AutoValue
  15271. - abstract static class ImageFormatProxy {
  15272. -
  15273. - abstract ColorSpaceType getColorSpaceType();
  15274. + @ImageFormat
  15275. + abstract int getImageFormat();
  15276. - @ImageFormat
  15277. - abstract int getImageFormat();
  15278. -
  15279. - static ImageFormatProxy createFromImageFormat(@ImageFormat int format) {
  15280. - switch (format) {
  15281. - case MlImage.IMAGE_FORMAT_RGB:
  15282. - return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.RGB, format);
  15283. - case MlImage.IMAGE_FORMAT_NV12:
  15284. - return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.NV12, format);
  15285. - case MlImage.IMAGE_FORMAT_NV21:
  15286. - return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.NV21, format);
  15287. - case MlImage.IMAGE_FORMAT_YV12:
  15288. - return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.YV12, format);
  15289. - case MlImage.IMAGE_FORMAT_YV21:
  15290. - return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.YV21, format);
  15291. - case MlImage.IMAGE_FORMAT_YUV_420_888:
  15292. - return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.YUV_420_888, format);
  15293. - case MlImage.IMAGE_FORMAT_ALPHA:
  15294. - return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.GRAYSCALE, format);
  15295. - case MlImage.IMAGE_FORMAT_RGBA:
  15296. - case MlImage.IMAGE_FORMAT_JPEG:
  15297. - case MlImage.IMAGE_FORMAT_UNKNOWN:
  15298. - throw new IllegalArgumentException(
  15299. - "Cannot create ColorSpaceType from MlImage format: " + format);
  15300. - default:
  15301. - throw new AssertionError("Illegal @ImageFormat: " + format);
  15302. - }
  15303. + static ImageFormatProxy createFromImageFormat(@ImageFormat int format) {
  15304. + switch (format) {
  15305. + case MlImage.IMAGE_FORMAT_RGB:
  15306. + return new AutoValue_MlImageAdapter_ImageFormatProxy(
  15307. + ColorSpaceType.RGB, format);
  15308. + case MlImage.IMAGE_FORMAT_NV12:
  15309. + return new AutoValue_MlImageAdapter_ImageFormatProxy(
  15310. + ColorSpaceType.NV12, format);
  15311. + case MlImage.IMAGE_FORMAT_NV21:
  15312. + return new AutoValue_MlImageAdapter_ImageFormatProxy(
  15313. + ColorSpaceType.NV21, format);
  15314. + case MlImage.IMAGE_FORMAT_YV12:
  15315. + return new AutoValue_MlImageAdapter_ImageFormatProxy(
  15316. + ColorSpaceType.YV12, format);
  15317. + case MlImage.IMAGE_FORMAT_YV21:
  15318. + return new AutoValue_MlImageAdapter_ImageFormatProxy(
  15319. + ColorSpaceType.YV21, format);
  15320. + case MlImage.IMAGE_FORMAT_YUV_420_888:
  15321. + return new AutoValue_MlImageAdapter_ImageFormatProxy(
  15322. + ColorSpaceType.YUV_420_888, format);
  15323. + case MlImage.IMAGE_FORMAT_ALPHA:
  15324. + return new AutoValue_MlImageAdapter_ImageFormatProxy(
  15325. + ColorSpaceType.GRAYSCALE, format);
  15326. + case MlImage.IMAGE_FORMAT_RGBA:
  15327. + case MlImage.IMAGE_FORMAT_JPEG:
  15328. + case MlImage.IMAGE_FORMAT_UNKNOWN:
  15329. + throw new IllegalArgumentException(
  15330. + "Cannot create ColorSpaceType from MlImage format: " + format);
  15331. + default:
  15332. + throw new AssertionError("Illegal @ImageFormat: " + format);
  15333. + }
  15334. + }
  15335. }
  15336. - }
  15337. - /**
  15338. - * Creates a {@link TensorImage} from an {@link MlImage}.
  15339. - *
  15340. - * <p>IMPORTANT: The returned {@link TensorImage} shares storage with {@code mlImage}, so do not
  15341. - * modify the contained object in the {@link TensorImage}, as {@code MlImage} expects its
  15342. - * contained data are immutable. Also, callers should use {@code MlImage#getInternal()#acquire()}
  15343. - * and {@code MlImage#release()} to avoid the {@code mlImage} being released unexpectedly.
  15344. - *
  15345. - * @throws IllegalArgumentException if the {@code mlImage} is built from an unsupported container.
  15346. - */
  15347. - public static TensorImage createTensorImageFrom(MlImage mlImage) {
  15348. - // TODO(b/190670174): Choose the best storage from multiple containers.
  15349. - com.google.android.odml.image.ImageProperties mlImageProperties =
  15350. - mlImage.getContainedImageProperties().get(0);
  15351. - switch (mlImageProperties.getStorageType()) {
  15352. - case MlImage.STORAGE_TYPE_BITMAP:
  15353. - return TensorImage.fromBitmap(BitmapExtractor.extract(mlImage));
  15354. - case MlImage.STORAGE_TYPE_MEDIA_IMAGE:
  15355. - TensorImage mediaTensorImage = new TensorImage();
  15356. - mediaTensorImage.load(MediaImageExtractor.extract(mlImage));
  15357. - return mediaTensorImage;
  15358. - case MlImage.STORAGE_TYPE_BYTEBUFFER:
  15359. - ByteBuffer buffer = ByteBufferExtractor.extract(mlImage);
  15360. - ImageFormatProxy formatProxy =
  15361. - ImageFormatProxy.createFromImageFormat(mlImageProperties.getImageFormat());
  15362. - TensorImage byteBufferTensorImage = new TensorImage();
  15363. - ImageProperties properties =
  15364. - ImageProperties.builder()
  15365. - .setColorSpaceType(formatProxy.getColorSpaceType())
  15366. - .setHeight(mlImage.getHeight())
  15367. - .setWidth(mlImage.getWidth())
  15368. - .build();
  15369. - byteBufferTensorImage.load(buffer, properties);
  15370. - return byteBufferTensorImage;
  15371. - default:
  15372. - throw new IllegalArgumentException(
  15373. - "Illegal storage type: " + mlImageProperties.getStorageType());
  15374. + /**
  15375. + * Creates a {@link TensorImage} from an {@link MlImage}.
  15376. + *
  15377. + * <p>IMPORTANT: The returned {@link TensorImage} shares storage with {@code mlImage}, so do not
  15378. + * modify the contained object in the {@link TensorImage}, as {@code MlImage} expects its
  15379. + * contained data are immutable. Also, callers should use {@code
  15380. + * MlImage#getInternal()#acquire()} and {@code MlImage#release()} to avoid the {@code mlImage}
  15381. + * being released unexpectedly.
  15382. + *
  15383. + * @throws IllegalArgumentException if the {@code mlImage} is built from an unsupported
  15384. + * container.
  15385. + */
  15386. + public static TensorImage createTensorImageFrom(MlImage mlImage) {
  15387. + // TODO(b/190670174): Choose the best storage from multiple containers.
  15388. + com.google.android.odml.image.ImageProperties mlImageProperties =
  15389. + mlImage.getContainedImageProperties().get(0);
  15390. + switch (mlImageProperties.getStorageType()) {
  15391. + case MlImage.STORAGE_TYPE_BITMAP:
  15392. + return TensorImage.fromBitmap(BitmapExtractor.extract(mlImage));
  15393. + case MlImage.STORAGE_TYPE_MEDIA_IMAGE:
  15394. + TensorImage mediaTensorImage = new TensorImage();
  15395. + mediaTensorImage.load(MediaImageExtractor.extract(mlImage));
  15396. + return mediaTensorImage;
  15397. + case MlImage.STORAGE_TYPE_BYTEBUFFER:
  15398. + ByteBuffer buffer = ByteBufferExtractor.extract(mlImage);
  15399. + ImageFormatProxy formatProxy =
  15400. + ImageFormatProxy.createFromImageFormat(mlImageProperties.getImageFormat());
  15401. + TensorImage byteBufferTensorImage = new TensorImage();
  15402. + ImageProperties properties =
  15403. + ImageProperties.builder()
  15404. + .setColorSpaceType(formatProxy.getColorSpaceType())
  15405. + .setHeight(mlImage.getHeight())
  15406. + .setWidth(mlImage.getWidth())
  15407. + .build();
  15408. + byteBufferTensorImage.load(buffer, properties);
  15409. + return byteBufferTensorImage;
  15410. + default:
  15411. + throw new IllegalArgumentException(
  15412. + "Illegal storage type: " + mlImageProperties.getStorageType());
  15413. + }
  15414. }
  15415. - }
  15416. - /** Creatas a {@link ColorSpaceType} from {@code MlImage.ImageFormat}. */
  15417. - public static ColorSpaceType createColorSpaceTypeFrom(@ImageFormat int imageFormat) {
  15418. - return ImageFormatProxy.createFromImageFormat(imageFormat).getColorSpaceType();
  15419. - }
  15420. + /** Creatas a {@link ColorSpaceType} from {@code MlImage.ImageFormat}. */
  15421. + public static ColorSpaceType createColorSpaceTypeFrom(@ImageFormat int imageFormat) {
  15422. + return ImageFormatProxy.createFromImageFormat(imageFormat).getColorSpaceType();
  15423. + }
  15424. - private MlImageAdapter() {}
  15425. + private MlImageAdapter() {}
  15426. }
  15427. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java
  15428. index 39e2ceb9db521..6dfef70ba67f7 100644
  15429. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java
  15430. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java
  15431. @@ -20,118 +20,108 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
  15432. import android.graphics.Bitmap;
  15433. import android.media.Image;
  15434. import android.util.Log;
  15435. +
  15436. import org.tensorflow.lite.DataType;
  15437. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  15438. /** Holds a {@link TensorBuffer} and converts it to other image formats as needed. */
  15439. final class TensorBufferContainer implements ImageContainer {
  15440. + private final TensorBuffer buffer;
  15441. + private final ColorSpaceType colorSpaceType;
  15442. + private final int height;
  15443. + private final int width;
  15444. + private static final String TAG = TensorBufferContainer.class.getSimpleName();
  15445. +
  15446. + /**
  15447. + * Creates a {@link TensorBufferContainer} object with the specified {@link
  15448. + * TensorImage#ColorSpaceType}.
  15449. + *
  15450. + * <p>Only supports {@link ColorSapceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link
  15451. + * #create(TensorBuffer, ImageProperties)} for other color space types.
  15452. + *
  15453. + * @throws IllegalArgumentException if the shape of the {@link TensorBuffer} does not match the
  15454. + * specified color space type, or if the color space type is not supported
  15455. + */
  15456. + static TensorBufferContainer create(TensorBuffer buffer, ColorSpaceType colorSpaceType) {
  15457. + checkArgument(
  15458. + colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE,
  15459. + "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
  15460. + + " `create(TensorBuffer, ImageProperties)` for other color space types.");
  15461. +
  15462. + return new TensorBufferContainer(buffer, colorSpaceType,
  15463. + colorSpaceType.getHeight(buffer.getShape()),
  15464. + colorSpaceType.getWidth(buffer.getShape()));
  15465. + }
  15466. - private final TensorBuffer buffer;
  15467. - private final ColorSpaceType colorSpaceType;
  15468. - private final int height;
  15469. - private final int width;
  15470. - private static final String TAG = TensorBufferContainer.class.getSimpleName();
  15471. -
  15472. - /**
  15473. - * Creates a {@link TensorBufferContainer} object with the specified {@link
  15474. - * TensorImage#ColorSpaceType}.
  15475. - *
  15476. - * <p>Only supports {@link ColorSapceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link
  15477. - * #create(TensorBuffer, ImageProperties)} for other color space types.
  15478. - *
  15479. - * @throws IllegalArgumentException if the shape of the {@link TensorBuffer} does not match the
  15480. - * specified color space type, or if the color space type is not supported
  15481. - */
  15482. - static TensorBufferContainer create(TensorBuffer buffer, ColorSpaceType colorSpaceType) {
  15483. - checkArgument(
  15484. - colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE,
  15485. - "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
  15486. - + " `create(TensorBuffer, ImageProperties)` for other color space types.");
  15487. -
  15488. - return new TensorBufferContainer(
  15489. - buffer,
  15490. - colorSpaceType,
  15491. - colorSpaceType.getHeight(buffer.getShape()),
  15492. - colorSpaceType.getWidth(buffer.getShape()));
  15493. - }
  15494. -
  15495. - static TensorBufferContainer create(TensorBuffer buffer, ImageProperties imageProperties) {
  15496. - return new TensorBufferContainer(
  15497. - buffer,
  15498. - imageProperties.getColorSpaceType(),
  15499. - imageProperties.getHeight(),
  15500. - imageProperties.getWidth());
  15501. - }
  15502. -
  15503. - private TensorBufferContainer(
  15504. - TensorBuffer buffer, ColorSpaceType colorSpaceType, int height, int width) {
  15505. - checkArgument(
  15506. - colorSpaceType != ColorSpaceType.YUV_420_888,
  15507. - "The actual encoding format of YUV420 is required. Choose a ColorSpaceType from: NV12,"
  15508. - + " NV21, YV12, YV21. Use YUV_420_888 only when loading an android.media.Image.");
  15509. -
  15510. - colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
  15511. - this.buffer = buffer;
  15512. - this.colorSpaceType = colorSpaceType;
  15513. - this.height = height;
  15514. - this.width = width;
  15515. - }
  15516. -
  15517. - @Override
  15518. - public TensorBufferContainer clone() {
  15519. - return new TensorBufferContainer(
  15520. - TensorBuffer.createFrom(buffer, buffer.getDataType()),
  15521. - colorSpaceType,
  15522. - getHeight(),
  15523. - getWidth());
  15524. - }
  15525. -
  15526. - @Override
  15527. - public Bitmap getBitmap() {
  15528. - if (buffer.getDataType() != DataType.UINT8) {
  15529. - // Print warning instead of throwing an exception. When using float models, users may want to
  15530. - // convert the resulting float image into Bitmap. That's fine to do so, as long as they are
  15531. - // aware of the potential accuracy lost when casting to uint8.
  15532. - Log.w(
  15533. - TAG,
  15534. - "<Warning> TensorBufferContainer is holding a non-uint8 image. The conversion to Bitmap"
  15535. - + " will cause numeric casting and clamping on the data value.");
  15536. + static TensorBufferContainer create(TensorBuffer buffer, ImageProperties imageProperties) {
  15537. + return new TensorBufferContainer(buffer, imageProperties.getColorSpaceType(),
  15538. + imageProperties.getHeight(), imageProperties.getWidth());
  15539. }
  15540. - return colorSpaceType.convertTensorBufferToBitmap(buffer);
  15541. - }
  15542. -
  15543. - @Override
  15544. - public TensorBuffer getTensorBuffer(DataType dataType) {
  15545. - // If the data type of buffer is desired, return it directly. Not making a defensive copy for
  15546. - // performance considerations. During image processing, users may need to set and get the
  15547. - // TensorBuffer many times.
  15548. - // Otherwise, create another one with the expected data type.
  15549. - return buffer.getDataType() == dataType ? buffer : TensorBuffer.createFrom(buffer, dataType);
  15550. - }
  15551. -
  15552. - @Override
  15553. - public Image getMediaImage() {
  15554. - throw new UnsupportedOperationException(
  15555. - "Converting from TensorBuffer to android.media.Image is unsupported.");
  15556. - }
  15557. -
  15558. - @Override
  15559. - public int getWidth() {
  15560. - // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created.
  15561. - colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
  15562. - return width;
  15563. - }
  15564. -
  15565. - @Override
  15566. - public int getHeight() {
  15567. - // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created.
  15568. - colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
  15569. - return height;
  15570. - }
  15571. -
  15572. - @Override
  15573. - public ColorSpaceType getColorSpaceType() {
  15574. - return colorSpaceType;
  15575. - }
  15576. + private TensorBufferContainer(
  15577. + TensorBuffer buffer, ColorSpaceType colorSpaceType, int height, int width) {
  15578. + checkArgument(colorSpaceType != ColorSpaceType.YUV_420_888,
  15579. + "The actual encoding format of YUV420 is required. Choose a ColorSpaceType from: NV12,"
  15580. + + " NV21, YV12, YV21. Use YUV_420_888 only when loading an android.media.Image.");
  15581. +
  15582. + colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
  15583. + this.buffer = buffer;
  15584. + this.colorSpaceType = colorSpaceType;
  15585. + this.height = height;
  15586. + this.width = width;
  15587. + }
  15588. +
  15589. + @Override
  15590. + public TensorBufferContainer clone() {
  15591. + return new TensorBufferContainer(TensorBuffer.createFrom(buffer, buffer.getDataType()),
  15592. + colorSpaceType, getHeight(), getWidth());
  15593. + }
  15594. +
  15595. + @Override
  15596. + public Bitmap getBitmap() {
  15597. + if (buffer.getDataType() != DataType.UINT8) {
  15598. + // Print warning instead of throwing an exception. When using float models, users may
  15599. + // want to convert the resulting float image into Bitmap. That's fine to do so, as long
  15600. + // as they are aware of the potential accuracy lost when casting to uint8.
  15601. + Log.w(TAG,
  15602. + "<Warning> TensorBufferContainer is holding a non-uint8 image. The conversion to Bitmap"
  15603. + + " will cause numeric casting and clamping on the data value.");
  15604. + }
  15605. +
  15606. + return colorSpaceType.convertTensorBufferToBitmap(buffer);
  15607. + }
  15608. +
  15609. + @Override
  15610. + public TensorBuffer getTensorBuffer(DataType dataType) {
  15611. + // If the data type of buffer is desired, return it directly. Not making a defensive copy
  15612. + // for performance considerations. During image processing, users may need to set and get
  15613. + // the TensorBuffer many times. Otherwise, create another one with the expected data type.
  15614. + return buffer.getDataType() == dataType ? buffer
  15615. + : TensorBuffer.createFrom(buffer, dataType);
  15616. + }
  15617. +
  15618. + @Override
  15619. + public Image getMediaImage() {
  15620. + throw new UnsupportedOperationException(
  15621. + "Converting from TensorBuffer to android.media.Image is unsupported.");
  15622. + }
  15623. +
  15624. + @Override
  15625. + public int getWidth() {
  15626. + // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created.
  15627. + colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
  15628. + return width;
  15629. + }
  15630. +
  15631. + @Override
  15632. + public int getHeight() {
  15633. + // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created.
  15634. + colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
  15635. + return height;
  15636. + }
  15637. +
  15638. + @Override
  15639. + public ColorSpaceType getColorSpaceType() {
  15640. + return colorSpaceType;
  15641. + }
  15642. }
  15643. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
  15644. index 1624971817aba..83cf4c0f648b2 100644
  15645. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
  15646. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
  15647. @@ -19,10 +19,12 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
  15648. import android.graphics.Bitmap;
  15649. import android.media.Image;
  15650. -import java.nio.ByteBuffer;
  15651. +
  15652. import org.tensorflow.lite.DataType;
  15653. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  15654. +import java.nio.ByteBuffer;
  15655. +
  15656. /**
  15657. * TensorImage is the wrapper class for Image object. When using image processing utils in
  15658. * TFLite.support library, it's common to convert image objects in variant types to TensorImage at
  15659. @@ -49,350 +51,357 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  15660. // TODO(b/138907116): Support loading images from TensorBuffer with properties.
  15661. // TODO(b/138905544): Support directly loading RGBBytes, YUVBytes and other types if necessary.
  15662. public class TensorImage {
  15663. + private final DataType dataType;
  15664. + private ImageContainer container = null;
  15665. +
  15666. + /**
  15667. + * Initializes a {@link TensorImage} object.
  15668. + *
  15669. + * <p>Note: the data type of this {@link TensorImage} is {@link DataType#UINT8}. Use {@link
  15670. + * #TensorImage(DataType)} if other data types are preferred.
  15671. + */
  15672. + public TensorImage() {
  15673. + this(DataType.UINT8);
  15674. + }
  15675. +
  15676. + /**
  15677. + * Initializes a {@link TensorImage} object with the specified data type.
  15678. + *
  15679. + * <p>When getting a {@link TensorBuffer} or a {@link ByteBuffer} from this {@link TensorImage},
  15680. + * such as using {@link #getTensorBuffer} and {@link #getBuffer}, the data values will be
  15681. + * converted to the specified data type.
  15682. + *
  15683. + * <p>Note: the shape of a {@link TensorImage} is not fixed. It can be adjusted to the shape of
  15684. + * the image being loaded to this {@link TensorImage}.
  15685. + *
  15686. + * @param dataType the expected data type of the resulting {@link TensorBuffer}. The type is
  15687. + * always fixed during the lifetime of the {@link TensorImage}. To convert the data type,
  15688. + * use
  15689. + * {@link #createFrom(TensorImage, DataType)} to create a copy and convert data type at the
  15690. + * same time.
  15691. + * @throws IllegalArgumentException if {@code dataType} is neither {@link DataType#UINT8} nor
  15692. + * {@link DataType#FLOAT32}
  15693. + */
  15694. + public TensorImage(DataType dataType) {
  15695. + checkArgument(dataType == DataType.UINT8 || dataType == DataType.FLOAT32,
  15696. + "Illegal data type for TensorImage: Only FLOAT32 and UINT8 are accepted");
  15697. + this.dataType = dataType;
  15698. + }
  15699. +
  15700. + /**
  15701. + * Initializes a {@link TensorImage} object of {@link DataType#UINT8} with a {@link
  15702. + * android.graphics.Bitmap} .
  15703. + *
  15704. + * @see #load(Bitmap) for reusing the object when it's expensive to create objects frequently,
  15705. + * because every call of {@code fromBitmap} creates a new {@link TensorImage}.
  15706. + */
  15707. + public static TensorImage fromBitmap(Bitmap bitmap) {
  15708. + TensorImage image = new TensorImage();
  15709. + image.load(bitmap);
  15710. + return image;
  15711. + }
  15712. +
  15713. + /**
  15714. + * Creates a deep-copy of a given {@link TensorImage} with the desired data type.
  15715. + *
  15716. + * @param src the {@link TensorImage} to copy from
  15717. + * @param dataType the expected data type of newly created {@link TensorImage}
  15718. + * @return a {@link TensorImage} whose data is copied from {@code src} and data type is {@code
  15719. + * dataType}
  15720. + */
  15721. + public static TensorImage createFrom(TensorImage src, DataType dataType) {
  15722. + TensorImage dst = new TensorImage(dataType);
  15723. + dst.container = src.container.clone();
  15724. + return dst;
  15725. + }
  15726. +
  15727. + /**
  15728. + * Loads a {@link android.graphics.Bitmap} image object into this {@link TensorImage}.
  15729. + *
  15730. + * <p>Note: if the {@link TensorImage} has data type other than {@link DataType#UINT8}, numeric
  15731. + * casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
  15732. + * #getBuffer}, where the {@link android.graphics.Bitmap} will be converted into a {@link
  15733. + * TensorBuffer}.
  15734. + *
  15735. + * <p>Important: when loading a bitmap, DO NOT MODIFY the bitmap from the caller side anymore.
  15736. + * The
  15737. + * {@link TensorImage} object will rely on the bitmap. It will probably modify the bitmap as
  15738. + * well. In this method, we perform a zero-copy approach for that bitmap, by simply holding its
  15739. + * reference. Use {@code bitmap.copy(bitmap.getConfig(), true)} to create a copy if necessary.
  15740. + *
  15741. + * <p>Note: to get the best performance, please load images in the same shape to avoid memory
  15742. + * re-allocation.
  15743. + *
  15744. + * @throws IllegalArgumentException if {@code bitmap} is not in ARGB_8888
  15745. + */
  15746. + public void load(Bitmap bitmap) {
  15747. + container = BitmapContainer.create(bitmap);
  15748. + }
  15749. +
  15750. + /**
  15751. + * Loads a float array as RGB pixels into this {@link TensorImage}, representing the pixels
  15752. + * inside.
  15753. + *
  15754. + * <p>Note: if the {@link TensorImage} has a data type other than {@link DataType#FLOAT32},
  15755. + * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
  15756. + * #getBuffer}.
  15757. + *
  15758. + * @param pixels the RGB pixels representing the image
  15759. + * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3)
  15760. + * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
  15761. + */
  15762. + public void load(float[] pixels, int[] shape) {
  15763. + TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
  15764. + buffer.loadArray(pixels, shape);
  15765. + load(buffer);
  15766. + }
  15767. - private final DataType dataType;
  15768. - private ImageContainer container = null;
  15769. -
  15770. - /**
  15771. - * Initializes a {@link TensorImage} object.
  15772. - *
  15773. - * <p>Note: the data type of this {@link TensorImage} is {@link DataType#UINT8}. Use {@link
  15774. - * #TensorImage(DataType)} if other data types are preferred.
  15775. - */
  15776. - public TensorImage() {
  15777. - this(DataType.UINT8);
  15778. - }
  15779. -
  15780. - /**
  15781. - * Initializes a {@link TensorImage} object with the specified data type.
  15782. - *
  15783. - * <p>When getting a {@link TensorBuffer} or a {@link ByteBuffer} from this {@link TensorImage},
  15784. - * such as using {@link #getTensorBuffer} and {@link #getBuffer}, the data values will be
  15785. - * converted to the specified data type.
  15786. - *
  15787. - * <p>Note: the shape of a {@link TensorImage} is not fixed. It can be adjusted to the shape of
  15788. - * the image being loaded to this {@link TensorImage}.
  15789. - *
  15790. - * @param dataType the expected data type of the resulting {@link TensorBuffer}. The type is
  15791. - * always fixed during the lifetime of the {@link TensorImage}. To convert the data type, use
  15792. - * {@link #createFrom(TensorImage, DataType)} to create a copy and convert data type at the
  15793. - * same time.
  15794. - * @throws IllegalArgumentException if {@code dataType} is neither {@link DataType#UINT8} nor
  15795. - * {@link DataType#FLOAT32}
  15796. - */
  15797. - public TensorImage(DataType dataType) {
  15798. - checkArgument(
  15799. - dataType == DataType.UINT8 || dataType == DataType.FLOAT32,
  15800. - "Illegal data type for TensorImage: Only FLOAT32 and UINT8 are accepted");
  15801. - this.dataType = dataType;
  15802. - }
  15803. -
  15804. - /**
  15805. - * Initializes a {@link TensorImage} object of {@link DataType#UINT8} with a {@link
  15806. - * android.graphics.Bitmap} .
  15807. - *
  15808. - * @see #load(Bitmap) for reusing the object when it's expensive to create objects frequently,
  15809. - * because every call of {@code fromBitmap} creates a new {@link TensorImage}.
  15810. - */
  15811. - public static TensorImage fromBitmap(Bitmap bitmap) {
  15812. - TensorImage image = new TensorImage();
  15813. - image.load(bitmap);
  15814. - return image;
  15815. - }
  15816. -
  15817. - /**
  15818. - * Creates a deep-copy of a given {@link TensorImage} with the desired data type.
  15819. - *
  15820. - * @param src the {@link TensorImage} to copy from
  15821. - * @param dataType the expected data type of newly created {@link TensorImage}
  15822. - * @return a {@link TensorImage} whose data is copied from {@code src} and data type is {@code
  15823. - * dataType}
  15824. - */
  15825. - public static TensorImage createFrom(TensorImage src, DataType dataType) {
  15826. - TensorImage dst = new TensorImage(dataType);
  15827. - dst.container = src.container.clone();
  15828. - return dst;
  15829. - }
  15830. -
  15831. - /**
  15832. - * Loads a {@link android.graphics.Bitmap} image object into this {@link TensorImage}.
  15833. - *
  15834. - * <p>Note: if the {@link TensorImage} has data type other than {@link DataType#UINT8}, numeric
  15835. - * casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
  15836. - * #getBuffer}, where the {@link android.graphics.Bitmap} will be converted into a {@link
  15837. - * TensorBuffer}.
  15838. - *
  15839. - * <p>Important: when loading a bitmap, DO NOT MODIFY the bitmap from the caller side anymore. The
  15840. - * {@link TensorImage} object will rely on the bitmap. It will probably modify the bitmap as well.
  15841. - * In this method, we perform a zero-copy approach for that bitmap, by simply holding its
  15842. - * reference. Use {@code bitmap.copy(bitmap.getConfig(), true)} to create a copy if necessary.
  15843. - *
  15844. - * <p>Note: to get the best performance, please load images in the same shape to avoid memory
  15845. - * re-allocation.
  15846. - *
  15847. - * @throws IllegalArgumentException if {@code bitmap} is not in ARGB_8888
  15848. - */
  15849. - public void load(Bitmap bitmap) {
  15850. - container = BitmapContainer.create(bitmap);
  15851. - }
  15852. -
  15853. - /**
  15854. - * Loads a float array as RGB pixels into this {@link TensorImage}, representing the pixels
  15855. - * inside.
  15856. - *
  15857. - * <p>Note: if the {@link TensorImage} has a data type other than {@link DataType#FLOAT32},
  15858. - * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
  15859. - * #getBuffer}.
  15860. - *
  15861. - * @param pixels the RGB pixels representing the image
  15862. - * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3)
  15863. - * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
  15864. - */
  15865. - public void load(float[] pixels, int[] shape) {
  15866. - TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
  15867. - buffer.loadArray(pixels, shape);
  15868. - load(buffer);
  15869. - }
  15870. -
  15871. - /**
  15872. - * Loads an int array as RGB pixels into this {@link TensorImage}, representing the pixels inside.
  15873. - *
  15874. - * <p>Note: numeric casting and clamping will be applied to convert the values into the data type
  15875. - * of this {@link TensorImage} when calling {@link #getTensorBuffer} and {@link #getBuffer}.
  15876. - *
  15877. - * @param pixels the RGB pixels representing the image
  15878. - * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3)
  15879. - * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
  15880. - */
  15881. - public void load(int[] pixels, int[] shape) {
  15882. - TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
  15883. - buffer.loadArray(pixels, shape);
  15884. - load(buffer);
  15885. - }
  15886. -
  15887. - /**
  15888. - * Loads a {@link TensorBuffer} containing pixel values. The color layout should be RGB.
  15889. - *
  15890. - * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
  15891. - * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
  15892. - * #getBuffer}.
  15893. - *
  15894. - * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or
  15895. - * (1, h, w, 3)
  15896. - * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
  15897. - */
  15898. - public void load(TensorBuffer buffer) {
  15899. - load(buffer, ColorSpaceType.RGB);
  15900. - }
  15901. -
  15902. - /**
  15903. - * Loads a {@link TensorBuffer} containing pixel values with the specific {@link ColorSpaceType}.
  15904. - *
  15905. - * <p>Only supports {@link ColorSpaceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link
  15906. - * #load(TensorBuffer, ImageProperties)} for other color space types.
  15907. - *
  15908. - * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
  15909. - * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
  15910. - * #getBuffer}.
  15911. - *
  15912. - * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or
  15913. - * (1, h, w, 3) for RGB images, and either (h, w) or (1, h, w) for GRAYSCALE images
  15914. - * @throws IllegalArgumentException if the shape of buffer does not match the color space type, or
  15915. - * if the color space type is not supported
  15916. - */
  15917. - public void load(TensorBuffer buffer, ColorSpaceType colorSpaceType) {
  15918. - checkArgument(
  15919. - colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE,
  15920. - "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
  15921. - + " `load(TensorBuffer, ImageProperties)` for other color space types.");
  15922. -
  15923. - container = TensorBufferContainer.create(buffer, colorSpaceType);
  15924. - }
  15925. -
  15926. - /**
  15927. - * Loads a {@link TensorBuffer} containing pixel values with the specific {@link ImageProperties}.
  15928. - *
  15929. - * <p>The shape of the {@link TensorBuffer} will not be used to determine image height and width.
  15930. - * Set image properties through {@link ImageProperties}.
  15931. - *
  15932. - * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
  15933. - * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
  15934. - * #getBuffer}.
  15935. - *
  15936. - * @throws IllegalArgumentException if buffer size is less than the image size indicated by image
  15937. - * height, width, and color space type in {@link ImageProperties}
  15938. - */
  15939. - public void load(TensorBuffer buffer, ImageProperties imageProperties) {
  15940. - container = TensorBufferContainer.create(buffer, imageProperties);
  15941. - }
  15942. -
  15943. - /**
  15944. - * Loads a {@link ByteBuffer} containing pixel values with the specific {@link ImageProperties}.
  15945. - *
  15946. - * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
  15947. - * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
  15948. - * #getBuffer}.
  15949. - *
  15950. - * @throws IllegalArgumentException if buffer size is less than the image size indicated by image
  15951. - * height, width, and color space type in {@link ImageProperties}
  15952. - */
  15953. - public void load(ByteBuffer buffer, ImageProperties imageProperties) {
  15954. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
  15955. - tensorBuffer.loadBuffer(buffer, new int[] {buffer.limit()});
  15956. - container = TensorBufferContainer.create(tensorBuffer, imageProperties);
  15957. - }
  15958. -
  15959. - /**
  15960. - * Loads an {@link android.media.Image} object into this {@link TensorImage}.
  15961. - *
  15962. - * <p>The main usage of this method is to load an {@link android.media.Image} object as model
  15963. - * input to the <a href="TFLite Task
  15964. - * Library">https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview</a>.
  15965. - * {@link TensorImage} backed by {@link android.media.Image} is not supported by {@link
  15966. - * ImageProcessor}.
  15967. - *
  15968. - * <p>* @throws IllegalArgumentException if the {@link android.graphics.ImageFormat} of {@code
  15969. - * image} is not YUV_420_888
  15970. - */
  15971. - public void load(Image image) {
  15972. - container = MediaImageContainer.create(image);
  15973. - }
  15974. -
  15975. - /**
  15976. - * Returns a {@link android.graphics.Bitmap} representation of this {@link TensorImage}.
  15977. - *
  15978. - * <p>Numeric casting and clamping will be applied if the stored data is not uint8.
  15979. - *
  15980. - * <p>Note that, the reliable way to get pixels from an {@code ALPHA_8} Bitmap is to use {@code
  15981. - * copyPixelsToBuffer}. Bitmap methods such as, `setPixels()` and `getPixels` do not work.
  15982. - *
  15983. - * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
  15984. - * concern, but if modification is necessary, please make a copy.
  15985. - *
  15986. - * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A"
  15987. - * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} of
  15988. - * this {@link TensorBuffer}.
  15989. - * @throws IllegalStateException if the {@link TensorImage} never loads data
  15990. - */
  15991. - public Bitmap getBitmap() {
  15992. - if (container == null) {
  15993. - throw new IllegalStateException("No image has been loaded yet.");
  15994. + /**
  15995. + * Loads an int array as RGB pixels into this {@link TensorImage}, representing the pixels
  15996. + * inside.
  15997. + *
  15998. + * <p>Note: numeric casting and clamping will be applied to convert the values into the data
  15999. + * type of this {@link TensorImage} when calling {@link #getTensorBuffer} and {@link
  16000. + * #getBuffer}.
  16001. + *
  16002. + * @param pixels the RGB pixels representing the image
  16003. + * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3)
  16004. + * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
  16005. + */
  16006. + public void load(int[] pixels, int[] shape) {
  16007. + TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
  16008. + buffer.loadArray(pixels, shape);
  16009. + load(buffer);
  16010. }
  16011. - return container.getBitmap();
  16012. - }
  16013. -
  16014. - /**
  16015. - * Returns a {@link ByteBuffer} representation of this {@link TensorImage} with the expected data
  16016. - * type.
  16017. - *
  16018. - * <p>Numeric casting and clamping will be applied if the stored data is different from the data
  16019. - * type of the {@link TensorImage}.
  16020. - *
  16021. - * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
  16022. - * concern, but if modification is necessary, please make a copy.
  16023. - *
  16024. - * <p>It's essentially a short cut for {@code getTensorBuffer().getBuffer()}.
  16025. - *
  16026. - * @return a reference to a {@link ByteBuffer} which holds the image data
  16027. - * @throws IllegalStateException if the {@link TensorImage} never loads data
  16028. - */
  16029. - public ByteBuffer getBuffer() {
  16030. - return getTensorBuffer().getBuffer();
  16031. - }
  16032. -
  16033. - /**
  16034. - * Returns a {@link TensorBuffer} representation of this {@link TensorImage} with the expected
  16035. - * data type.
  16036. - *
  16037. - * <p>Numeric casting and clamping will be applied if the stored data is different from the data
  16038. - * type of the {@link TensorImage}.
  16039. - *
  16040. - * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
  16041. - * concern, but if modification is necessary, please make a copy.
  16042. - *
  16043. - * @return a reference to a {@link TensorBuffer} which holds the image data
  16044. - * @throws IllegalStateException if the {@link TensorImage} never loads data
  16045. - */
  16046. - public TensorBuffer getTensorBuffer() {
  16047. - if (container == null) {
  16048. - throw new IllegalStateException("No image has been loaded yet.");
  16049. + /**
  16050. + * Loads a {@link TensorBuffer} containing pixel values. The color layout should be RGB.
  16051. + *
  16052. + * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
  16053. + * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
  16054. + * #getBuffer}.
  16055. + *
  16056. + * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or
  16057. + * (1, h, w, 3)
  16058. + * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
  16059. + */
  16060. + public void load(TensorBuffer buffer) {
  16061. + load(buffer, ColorSpaceType.RGB);
  16062. }
  16063. - return container.getTensorBuffer(dataType);
  16064. - }
  16065. -
  16066. - /**
  16067. - * Returns an {@link android.media.Image} representation of this {@link TensorImage}.
  16068. - *
  16069. - * <p>This method only works when the {@link TensorImage} is backed by an {@link
  16070. - * android.media.Image}, meaning you need to first load an {@link android.media.Image} through
  16071. - * {@link #load(Image)}.
  16072. - *
  16073. - * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
  16074. - * concern, but if modification is necessary, please make a copy.
  16075. - *
  16076. - * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A"
  16077. - * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} of
  16078. - * this {@link TensorBuffer}.
  16079. - * @throws IllegalStateException if the {@link TensorImage} never loads data
  16080. - */
  16081. - public Image getMediaImage() {
  16082. - if (container == null) {
  16083. - throw new IllegalStateException("No image has been loaded yet.");
  16084. + /**
  16085. + * Loads a {@link TensorBuffer} containing pixel values with the specific {@link
  16086. + * ColorSpaceType}.
  16087. + *
  16088. + * <p>Only supports {@link ColorSpaceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link
  16089. + * #load(TensorBuffer, ImageProperties)} for other color space types.
  16090. + *
  16091. + * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
  16092. + * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
  16093. + * #getBuffer}.
  16094. + *
  16095. + * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or
  16096. + * (1, h, w, 3) for RGB images, and either (h, w) or (1, h, w) for GRAYSCALE images
  16097. + * @throws IllegalArgumentException if the shape of buffer does not match the color space type,
  16098. + * or
  16099. + * if the color space type is not supported
  16100. + */
  16101. + public void load(TensorBuffer buffer, ColorSpaceType colorSpaceType) {
  16102. + checkArgument(
  16103. + colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE,
  16104. + "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
  16105. + + " `load(TensorBuffer, ImageProperties)` for other color space types.");
  16106. +
  16107. + container = TensorBufferContainer.create(buffer, colorSpaceType);
  16108. }
  16109. - return container.getMediaImage();
  16110. - }
  16111. -
  16112. - /**
  16113. - * Gets the data type of this {@link TensorImage}.
  16114. - *
  16115. - * @return a data type. Currently only {@link DataType#UINT8} and {@link DataType#FLOAT32} are
  16116. - * supported.
  16117. - */
  16118. - public DataType getDataType() {
  16119. - return dataType;
  16120. - }
  16121. -
  16122. - /**
  16123. - * Gets the color space type of this {@link TensorImage}.
  16124. - *
  16125. - * @throws IllegalStateException if the {@link TensorImage} never loads data
  16126. - */
  16127. - public ColorSpaceType getColorSpaceType() {
  16128. - if (container == null) {
  16129. - throw new IllegalStateException("No image has been loaded yet.");
  16130. + /**
  16131. + * Loads a {@link TensorBuffer} containing pixel values with the specific {@link
  16132. + * ImageProperties}.
  16133. + *
  16134. + * <p>The shape of the {@link TensorBuffer} will not be used to determine image height and
  16135. + * width. Set image properties through {@link ImageProperties}.
  16136. + *
  16137. + * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
  16138. + * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
  16139. + * #getBuffer}.
  16140. + *
  16141. + * @throws IllegalArgumentException if buffer size is less than the image size indicated by
  16142. + * image
  16143. + * height, width, and color space type in {@link ImageProperties}
  16144. + */
  16145. + public void load(TensorBuffer buffer, ImageProperties imageProperties) {
  16146. + container = TensorBufferContainer.create(buffer, imageProperties);
  16147. }
  16148. - return container.getColorSpaceType();
  16149. - }
  16150. -
  16151. - /**
  16152. - * Gets the image width.
  16153. - *
  16154. - * @throws IllegalStateException if the {@link TensorImage} never loads data
  16155. - * @throws IllegalArgumentException if the underlying data is corrupted
  16156. - */
  16157. - public int getWidth() {
  16158. - if (container == null) {
  16159. - throw new IllegalStateException("No image has been loaded yet.");
  16160. + /**
  16161. + * Loads a {@link ByteBuffer} containing pixel values with the specific {@link ImageProperties}.
  16162. + *
  16163. + * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
  16164. + * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
  16165. + * #getBuffer}.
  16166. + *
  16167. + * @throws IllegalArgumentException if buffer size is less than the image size indicated by
  16168. + * image
  16169. + * height, width, and color space type in {@link ImageProperties}
  16170. + */
  16171. + public void load(ByteBuffer buffer, ImageProperties imageProperties) {
  16172. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
  16173. + tensorBuffer.loadBuffer(buffer, new int[] {buffer.limit()});
  16174. + container = TensorBufferContainer.create(tensorBuffer, imageProperties);
  16175. }
  16176. - return container.getWidth();
  16177. - }
  16178. -
  16179. - /**
  16180. - * Gets the image height.
  16181. - *
  16182. - * @throws IllegalStateException if the {@link TensorImage} never loads data
  16183. - * @throws IllegalArgumentException if the underlying data is corrupted
  16184. - */
  16185. - public int getHeight() {
  16186. - if (container == null) {
  16187. - throw new IllegalStateException("No image has been loaded yet.");
  16188. + /**
  16189. + * Loads an {@link android.media.Image} object into this {@link TensorImage}.
  16190. + *
  16191. + * <p>The main usage of this method is to load an {@link android.media.Image} object as model
  16192. + * input to the <a href="TFLite Task
  16193. + * Library">https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview</a>.
  16194. + * {@link TensorImage} backed by {@link android.media.Image} is not supported by {@link
  16195. + * ImageProcessor}.
  16196. + *
  16197. + * <p>* @throws IllegalArgumentException if the {@link android.graphics.ImageFormat} of {@code
  16198. + * image} is not YUV_420_888
  16199. + */
  16200. + public void load(Image image) {
  16201. + container = MediaImageContainer.create(image);
  16202. }
  16203. - return container.getHeight();
  16204. - }
  16205. + /**
  16206. + * Returns a {@link android.graphics.Bitmap} representation of this {@link TensorImage}.
  16207. + *
  16208. + * <p>Numeric casting and clamping will be applied if the stored data is not uint8.
  16209. + *
  16210. + * <p>Note that, the reliable way to get pixels from an {@code ALPHA_8} Bitmap is to use {@code
  16211. + * copyPixelsToBuffer}. Bitmap methods such as, `setPixels()` and `getPixels` do not work.
  16212. + *
  16213. + * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for
  16214. + * performance concern, but if modification is necessary, please make a copy.
  16215. + *
  16216. + * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A"
  16217. + * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType}
  16218. + * of this {@link TensorBuffer}.
  16219. + * @throws IllegalStateException if the {@link TensorImage} never loads data
  16220. + */
  16221. + public Bitmap getBitmap() {
  16222. + if (container == null) {
  16223. + throw new IllegalStateException("No image has been loaded yet.");
  16224. + }
  16225. +
  16226. + return container.getBitmap();
  16227. + }
  16228. +
  16229. + /**
  16230. + * Returns a {@link ByteBuffer} representation of this {@link TensorImage} with the expected
  16231. + * data type.
  16232. + *
  16233. + * <p>Numeric casting and clamping will be applied if the stored data is different from the data
  16234. + * type of the {@link TensorImage}.
  16235. + *
  16236. + * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for
  16237. + * performance concern, but if modification is necessary, please make a copy.
  16238. + *
  16239. + * <p>It's essentially a short cut for {@code getTensorBuffer().getBuffer()}.
  16240. + *
  16241. + * @return a reference to a {@link ByteBuffer} which holds the image data
  16242. + * @throws IllegalStateException if the {@link TensorImage} never loads data
  16243. + */
  16244. + public ByteBuffer getBuffer() {
  16245. + return getTensorBuffer().getBuffer();
  16246. + }
  16247. +
  16248. + /**
  16249. + * Returns a {@link TensorBuffer} representation of this {@link TensorImage} with the expected
  16250. + * data type.
  16251. + *
  16252. + * <p>Numeric casting and clamping will be applied if the stored data is different from the data
  16253. + * type of the {@link TensorImage}.
  16254. + *
  16255. + * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for
  16256. + * performance concern, but if modification is necessary, please make a copy.
  16257. + *
  16258. + * @return a reference to a {@link TensorBuffer} which holds the image data
  16259. + * @throws IllegalStateException if the {@link TensorImage} never loads data
  16260. + */
  16261. + public TensorBuffer getTensorBuffer() {
  16262. + if (container == null) {
  16263. + throw new IllegalStateException("No image has been loaded yet.");
  16264. + }
  16265. +
  16266. + return container.getTensorBuffer(dataType);
  16267. + }
  16268. +
  16269. + /**
  16270. + * Returns an {@link android.media.Image} representation of this {@link TensorImage}.
  16271. + *
  16272. + * <p>This method only works when the {@link TensorImage} is backed by an {@link
  16273. + * android.media.Image}, meaning you need to first load an {@link android.media.Image} through
  16274. + * {@link #load(Image)}.
  16275. + *
  16276. + * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for
  16277. + * performance concern, but if modification is necessary, please make a copy.
  16278. + *
  16279. + * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A"
  16280. + * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType}
  16281. + * of this {@link TensorBuffer}.
  16282. + * @throws IllegalStateException if the {@link TensorImage} never loads data
  16283. + */
  16284. + public Image getMediaImage() {
  16285. + if (container == null) {
  16286. + throw new IllegalStateException("No image has been loaded yet.");
  16287. + }
  16288. +
  16289. + return container.getMediaImage();
  16290. + }
  16291. +
  16292. + /**
  16293. + * Gets the data type of this {@link TensorImage}.
  16294. + *
  16295. + * @return a data type. Currently only {@link DataType#UINT8} and {@link DataType#FLOAT32} are
  16296. + * supported.
  16297. + */
  16298. + public DataType getDataType() {
  16299. + return dataType;
  16300. + }
  16301. +
  16302. + /**
  16303. + * Gets the color space type of this {@link TensorImage}.
  16304. + *
  16305. + * @throws IllegalStateException if the {@link TensorImage} never loads data
  16306. + */
  16307. + public ColorSpaceType getColorSpaceType() {
  16308. + if (container == null) {
  16309. + throw new IllegalStateException("No image has been loaded yet.");
  16310. + }
  16311. +
  16312. + return container.getColorSpaceType();
  16313. + }
  16314. +
  16315. + /**
  16316. + * Gets the image width.
  16317. + *
  16318. + * @throws IllegalStateException if the {@link TensorImage} never loads data
  16319. + * @throws IllegalArgumentException if the underlying data is corrupted
  16320. + */
  16321. + public int getWidth() {
  16322. + if (container == null) {
  16323. + throw new IllegalStateException("No image has been loaded yet.");
  16324. + }
  16325. +
  16326. + return container.getWidth();
  16327. + }
  16328. +
  16329. + /**
  16330. + * Gets the image height.
  16331. + *
  16332. + * @throws IllegalStateException if the {@link TensorImage} never loads data
  16333. + * @throws IllegalArgumentException if the underlying data is corrupted
  16334. + */
  16335. + public int getHeight() {
  16336. + if (container == null) {
  16337. + throw new IllegalStateException("No image has been loaded yet.");
  16338. + }
  16339. +
  16340. + return container.getHeight();
  16341. + }
  16342. }
  16343. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java
  16344. index 06391de9cc3e0..adccf23dc97f0 100644
  16345. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java
  16346. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java
  16347. @@ -19,6 +19,7 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
  16348. import android.graphics.Bitmap;
  16349. import android.graphics.PointF;
  16350. +
  16351. import org.checkerframework.checker.nullness.qual.NonNull;
  16352. import org.tensorflow.lite.support.image.ColorSpaceType;
  16353. import org.tensorflow.lite.support.image.ImageOperator;
  16354. @@ -32,64 +33,60 @@ import org.tensorflow.lite.support.image.TensorImage;
  16355. * @see ResizeWithCropOrPadOp for resizing without content distortion.
  16356. */
  16357. public class ResizeOp implements ImageOperator {
  16358. + /** Algorithms for resizing. */
  16359. + public enum ResizeMethod { BILINEAR, NEAREST_NEIGHBOR }
  16360. - /** Algorithms for resizing. */
  16361. - public enum ResizeMethod {
  16362. - BILINEAR,
  16363. - NEAREST_NEIGHBOR
  16364. - }
  16365. -
  16366. - private final int targetHeight;
  16367. - private final int targetWidth;
  16368. - private final boolean useBilinear;
  16369. + private final int targetHeight;
  16370. + private final int targetWidth;
  16371. + private final boolean useBilinear;
  16372. - /**
  16373. - * Creates a ResizeOp which can resize images to specified size in specified method.
  16374. - *
  16375. - * @param targetHeight The expected height of resized image.
  16376. - * @param targetWidth The expected width of resized image.
  16377. - * @param resizeMethod The algorithm to use for resizing. Options: {@link ResizeMethod}
  16378. - */
  16379. - public ResizeOp(int targetHeight, int targetWidth, ResizeMethod resizeMethod) {
  16380. - this.targetHeight = targetHeight;
  16381. - this.targetWidth = targetWidth;
  16382. - useBilinear = (resizeMethod == ResizeMethod.BILINEAR);
  16383. - }
  16384. + /**
  16385. + * Creates a ResizeOp which can resize images to specified size in specified method.
  16386. + *
  16387. + * @param targetHeight The expected height of resized image.
  16388. + * @param targetWidth The expected width of resized image.
  16389. + * @param resizeMethod The algorithm to use for resizing. Options: {@link ResizeMethod}
  16390. + */
  16391. + public ResizeOp(int targetHeight, int targetWidth, ResizeMethod resizeMethod) {
  16392. + this.targetHeight = targetHeight;
  16393. + this.targetWidth = targetWidth;
  16394. + useBilinear = (resizeMethod == ResizeMethod.BILINEAR);
  16395. + }
  16396. - /**
  16397. - * Applies the defined resizing on given image and returns the result.
  16398. - *
  16399. - * <p>Note: the content of input {@code image} will change, and {@code image} is the same instance
  16400. - * with the output.
  16401. - *
  16402. - * @param image input image.
  16403. - * @return output image.
  16404. - */
  16405. - @Override
  16406. - @NonNull
  16407. - public TensorImage apply(@NonNull TensorImage image) {
  16408. - checkArgument(
  16409. - image.getColorSpaceType() == ColorSpaceType.RGB,
  16410. - "Only RGB images are supported in ResizeOp, but not " + image.getColorSpaceType().name());
  16411. - Bitmap scaled =
  16412. - Bitmap.createScaledBitmap(image.getBitmap(), targetWidth, targetHeight, useBilinear);
  16413. - image.load(scaled);
  16414. - return image;
  16415. - }
  16416. + /**
  16417. + * Applies the defined resizing on given image and returns the result.
  16418. + *
  16419. + * <p>Note: the content of input {@code image} will change, and {@code image} is the same
  16420. + * instance with the output.
  16421. + *
  16422. + * @param image input image.
  16423. + * @return output image.
  16424. + */
  16425. + @Override
  16426. + @NonNull
  16427. + public TensorImage apply(@NonNull TensorImage image) {
  16428. + checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB,
  16429. + "Only RGB images are supported in ResizeOp, but not "
  16430. + + image.getColorSpaceType().name());
  16431. + Bitmap scaled = Bitmap.createScaledBitmap(
  16432. + image.getBitmap(), targetWidth, targetHeight, useBilinear);
  16433. + image.load(scaled);
  16434. + return image;
  16435. + }
  16436. - @Override
  16437. - public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
  16438. - return targetHeight;
  16439. - }
  16440. + @Override
  16441. + public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
  16442. + return targetHeight;
  16443. + }
  16444. - @Override
  16445. - public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
  16446. - return targetWidth;
  16447. - }
  16448. + @Override
  16449. + public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
  16450. + return targetWidth;
  16451. + }
  16452. - @Override
  16453. - public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
  16454. - return new PointF(
  16455. - point.x * inputImageWidth / targetWidth, point.y * inputImageHeight / targetHeight);
  16456. - }
  16457. + @Override
  16458. + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
  16459. + return new PointF(
  16460. + point.x * inputImageWidth / targetWidth, point.y * inputImageHeight / targetHeight);
  16461. + }
  16462. }
  16463. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java
  16464. index 66491090ac9c0..e5de5bbcf50d9 100644
  16465. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java
  16466. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java
  16467. @@ -22,6 +22,7 @@ import android.graphics.Bitmap.Config;
  16468. import android.graphics.Canvas;
  16469. import android.graphics.PointF;
  16470. import android.graphics.Rect;
  16471. +
  16472. import org.checkerframework.checker.nullness.qual.NonNull;
  16473. import org.tensorflow.lite.support.image.ColorSpaceType;
  16474. import org.tensorflow.lite.support.image.ImageOperator;
  16475. @@ -37,96 +38,95 @@ import org.tensorflow.lite.support.image.TensorImage;
  16476. * @see ResizeOp for reszing images while stretching / compressing the content.
  16477. */
  16478. public class ResizeWithCropOrPadOp implements ImageOperator {
  16479. - private final int targetHeight;
  16480. - private final int targetWidth;
  16481. - private final Bitmap output;
  16482. -
  16483. - /**
  16484. - * Creates a ResizeWithCropOrPadOp which could crop/pad images to specified size. It adopts
  16485. - * center-crop and zero-padding.
  16486. - *
  16487. - * @param targetHeight The expected height of cropped/padded image.
  16488. - * @param targetWidth The expected width of cropped/padded image.
  16489. - */
  16490. - public ResizeWithCropOrPadOp(int targetHeight, int targetWidth) {
  16491. - this.targetHeight = targetHeight;
  16492. - this.targetWidth = targetWidth;
  16493. - output = Bitmap.createBitmap(this.targetWidth, this.targetHeight, Config.ARGB_8888);
  16494. - }
  16495. + private final int targetHeight;
  16496. + private final int targetWidth;
  16497. + private final Bitmap output;
  16498. - /**
  16499. - * Applies the defined resizing with cropping or/and padding on given image and returns the
  16500. - * result.
  16501. - *
  16502. - * <p>Note: the content of input {@code image} will change, and {@code image} is the same instance
  16503. - * with the output.
  16504. - *
  16505. - * @param image input image.
  16506. - * @return output image.
  16507. - */
  16508. - @Override
  16509. - @NonNull
  16510. - public TensorImage apply(@NonNull TensorImage image) {
  16511. - checkArgument(
  16512. - image.getColorSpaceType() == ColorSpaceType.RGB,
  16513. - "Only RGB images are supported in ResizeWithCropOrPadOp, but not "
  16514. - + image.getColorSpaceType().name());
  16515. - Bitmap input = image.getBitmap();
  16516. - int srcL;
  16517. - int srcR;
  16518. - int srcT;
  16519. - int srcB;
  16520. - int dstL;
  16521. - int dstR;
  16522. - int dstT;
  16523. - int dstB;
  16524. - int w = input.getWidth();
  16525. - int h = input.getHeight();
  16526. - if (targetWidth > w) { // padding
  16527. - srcL = 0;
  16528. - srcR = w;
  16529. - dstL = (targetWidth - w) / 2;
  16530. - dstR = dstL + w;
  16531. - } else { // cropping
  16532. - dstL = 0;
  16533. - dstR = targetWidth;
  16534. - srcL = (w - targetWidth) / 2;
  16535. - srcR = srcL + targetWidth;
  16536. + /**
  16537. + * Creates a ResizeWithCropOrPadOp which could crop/pad images to specified size. It adopts
  16538. + * center-crop and zero-padding.
  16539. + *
  16540. + * @param targetHeight The expected height of cropped/padded image.
  16541. + * @param targetWidth The expected width of cropped/padded image.
  16542. + */
  16543. + public ResizeWithCropOrPadOp(int targetHeight, int targetWidth) {
  16544. + this.targetHeight = targetHeight;
  16545. + this.targetWidth = targetWidth;
  16546. + output = Bitmap.createBitmap(this.targetWidth, this.targetHeight, Config.ARGB_8888);
  16547. }
  16548. - if (targetHeight > h) { // padding
  16549. - srcT = 0;
  16550. - srcB = h;
  16551. - dstT = (targetHeight - h) / 2;
  16552. - dstB = dstT + h;
  16553. - } else { // cropping
  16554. - dstT = 0;
  16555. - dstB = targetHeight;
  16556. - srcT = (h - targetHeight) / 2;
  16557. - srcB = srcT + targetHeight;
  16558. +
  16559. + /**
  16560. + * Applies the defined resizing with cropping or/and padding on given image and returns the
  16561. + * result.
  16562. + *
  16563. + * <p>Note: the content of input {@code image} will change, and {@code image} is the same
  16564. + * instance with the output.
  16565. + *
  16566. + * @param image input image.
  16567. + * @return output image.
  16568. + */
  16569. + @Override
  16570. + @NonNull
  16571. + public TensorImage apply(@NonNull TensorImage image) {
  16572. + checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB,
  16573. + "Only RGB images are supported in ResizeWithCropOrPadOp, but not "
  16574. + + image.getColorSpaceType().name());
  16575. + Bitmap input = image.getBitmap();
  16576. + int srcL;
  16577. + int srcR;
  16578. + int srcT;
  16579. + int srcB;
  16580. + int dstL;
  16581. + int dstR;
  16582. + int dstT;
  16583. + int dstB;
  16584. + int w = input.getWidth();
  16585. + int h = input.getHeight();
  16586. + if (targetWidth > w) { // padding
  16587. + srcL = 0;
  16588. + srcR = w;
  16589. + dstL = (targetWidth - w) / 2;
  16590. + dstR = dstL + w;
  16591. + } else { // cropping
  16592. + dstL = 0;
  16593. + dstR = targetWidth;
  16594. + srcL = (w - targetWidth) / 2;
  16595. + srcR = srcL + targetWidth;
  16596. + }
  16597. + if (targetHeight > h) { // padding
  16598. + srcT = 0;
  16599. + srcB = h;
  16600. + dstT = (targetHeight - h) / 2;
  16601. + dstB = dstT + h;
  16602. + } else { // cropping
  16603. + dstT = 0;
  16604. + dstB = targetHeight;
  16605. + srcT = (h - targetHeight) / 2;
  16606. + srcB = srcT + targetHeight;
  16607. + }
  16608. + Rect src = new Rect(srcL, srcT, srcR, srcB);
  16609. + Rect dst = new Rect(dstL, dstT, dstR, dstB);
  16610. + new Canvas(output).drawBitmap(input, src, dst, null);
  16611. + image.load(output);
  16612. + return image;
  16613. }
  16614. - Rect src = new Rect(srcL, srcT, srcR, srcB);
  16615. - Rect dst = new Rect(dstL, dstT, dstR, dstB);
  16616. - new Canvas(output).drawBitmap(input, src, dst, null);
  16617. - image.load(output);
  16618. - return image;
  16619. - }
  16620. - @Override
  16621. - public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
  16622. - return targetHeight;
  16623. - }
  16624. + @Override
  16625. + public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
  16626. + return targetHeight;
  16627. + }
  16628. - @Override
  16629. - public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
  16630. - return targetWidth;
  16631. - }
  16632. + @Override
  16633. + public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
  16634. + return targetWidth;
  16635. + }
  16636. - @Override
  16637. - public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
  16638. - return transformImpl(point, targetHeight, targetWidth, inputImageHeight, inputImageWidth);
  16639. - }
  16640. + @Override
  16641. + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
  16642. + return transformImpl(point, targetHeight, targetWidth, inputImageHeight, inputImageWidth);
  16643. + }
  16644. - private static PointF transformImpl(PointF point, int srcH, int srcW, int dstH, int dstW) {
  16645. - return new PointF(point.x + (dstW - srcW) / 2, point.y + (dstH - srcH) / 2);
  16646. - }
  16647. + private static PointF transformImpl(PointF point, int srcH, int srcW, int dstH, int dstW) {
  16648. + return new PointF(point.x + (dstW - srcW) / 2, point.y + (dstH - srcH) / 2);
  16649. + }
  16650. }
  16651. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java
  16652. index 849b4bc9ef3db..86413c90c69ca 100644
  16653. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java
  16654. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java
  16655. @@ -20,6 +20,7 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
  16656. import android.graphics.Bitmap;
  16657. import android.graphics.Matrix;
  16658. import android.graphics.PointF;
  16659. +
  16660. import org.checkerframework.checker.nullness.qual.NonNull;
  16661. import org.tensorflow.lite.support.image.ColorSpaceType;
  16662. import org.tensorflow.lite.support.image.ImageOperator;
  16663. @@ -27,83 +28,83 @@ import org.tensorflow.lite.support.image.TensorImage;
  16664. /** Rotates image counter-clockwise. */
  16665. public class Rot90Op implements ImageOperator {
  16666. + private final int numRotation;
  16667. - private final int numRotation;
  16668. -
  16669. - /** Creates a Rot90 Op which will rotate image by 90 degree counter-clockwise. */
  16670. - public Rot90Op() {
  16671. - this(1);
  16672. - }
  16673. + /** Creates a Rot90 Op which will rotate image by 90 degree counter-clockwise. */
  16674. + public Rot90Op() {
  16675. + this(1);
  16676. + }
  16677. - /**
  16678. - * Creates a Rot90 Op which will rotate image by 90 degree for {@code k} times counter-clockwise.
  16679. - *
  16680. - * @param k The number of times the image is rotated by 90 degrees. If it's positive, the image
  16681. - * will be rotated counter-clockwise. If it's negative, the op will rotate image clockwise.
  16682. - */
  16683. - public Rot90Op(int k) {
  16684. - numRotation = k % 4;
  16685. - }
  16686. + /**
  16687. + * Creates a Rot90 Op which will rotate image by 90 degree for {@code k} times
  16688. + * counter-clockwise.
  16689. + *
  16690. + * @param k The number of times the image is rotated by 90 degrees. If it's positive, the image
  16691. + * will be rotated counter-clockwise. If it's negative, the op will rotate image clockwise.
  16692. + */
  16693. + public Rot90Op(int k) {
  16694. + numRotation = k % 4;
  16695. + }
  16696. - /**
  16697. - * Applies the defined rotation on given image and returns the result.
  16698. - *
  16699. - * <p>Note: the content of input {@code image} will change, and {@code image} is the same instance
  16700. - * with the output.
  16701. - *
  16702. - * @param image input image.
  16703. - * @return output image.
  16704. - */
  16705. - @NonNull
  16706. - @Override
  16707. - public TensorImage apply(@NonNull TensorImage image) {
  16708. - checkArgument(
  16709. - image.getColorSpaceType() == ColorSpaceType.RGB,
  16710. - "Only RGB images are supported in Rot90Op, but not " + image.getColorSpaceType().name());
  16711. - Bitmap input = image.getBitmap();
  16712. - if (numRotation == 0) {
  16713. - return image;
  16714. + /**
  16715. + * Applies the defined rotation on given image and returns the result.
  16716. + *
  16717. + * <p>Note: the content of input {@code image} will change, and {@code image} is the same
  16718. + * instance with the output.
  16719. + *
  16720. + * @param image input image.
  16721. + * @return output image.
  16722. + */
  16723. + @NonNull
  16724. + @Override
  16725. + public TensorImage apply(@NonNull TensorImage image) {
  16726. + checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB,
  16727. + "Only RGB images are supported in Rot90Op, but not "
  16728. + + image.getColorSpaceType().name());
  16729. + Bitmap input = image.getBitmap();
  16730. + if (numRotation == 0) {
  16731. + return image;
  16732. + }
  16733. + int w = input.getWidth();
  16734. + int h = input.getHeight();
  16735. + Matrix matrix = new Matrix();
  16736. + matrix.postTranslate(w * 0.5f, h * 0.5f);
  16737. + matrix.postRotate(-90 * numRotation);
  16738. + int newW = (numRotation % 2 == 0) ? w : h;
  16739. + int newH = (numRotation % 2 == 0) ? h : w;
  16740. + matrix.postTranslate(newW * 0.5f, newH * 0.5f);
  16741. + Bitmap output = Bitmap.createBitmap(input, 0, 0, w, h, matrix, false);
  16742. + image.load(output);
  16743. + return image;
  16744. }
  16745. - int w = input.getWidth();
  16746. - int h = input.getHeight();
  16747. - Matrix matrix = new Matrix();
  16748. - matrix.postTranslate(w * 0.5f, h * 0.5f);
  16749. - matrix.postRotate(-90 * numRotation);
  16750. - int newW = (numRotation % 2 == 0) ? w : h;
  16751. - int newH = (numRotation % 2 == 0) ? h : w;
  16752. - matrix.postTranslate(newW * 0.5f, newH * 0.5f);
  16753. - Bitmap output = Bitmap.createBitmap(input, 0, 0, w, h, matrix, false);
  16754. - image.load(output);
  16755. - return image;
  16756. - }
  16757. - @Override
  16758. - public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
  16759. - return (numRotation % 2 == 0) ? inputImageHeight : inputImageWidth;
  16760. - }
  16761. + @Override
  16762. + public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
  16763. + return (numRotation % 2 == 0) ? inputImageHeight : inputImageWidth;
  16764. + }
  16765. - @Override
  16766. - public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
  16767. - return (numRotation % 2 == 0) ? inputImageWidth : inputImageHeight;
  16768. - }
  16769. + @Override
  16770. + public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
  16771. + return (numRotation % 2 == 0) ? inputImageWidth : inputImageHeight;
  16772. + }
  16773. - @Override
  16774. - public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
  16775. - int inverseNumRotation = (4 - numRotation) % 4;
  16776. - int height = getOutputImageHeight(inputImageHeight, inputImageWidth);
  16777. - int width = getOutputImageWidth(inputImageHeight, inputImageWidth);
  16778. - return transformImpl(point, height, width, inverseNumRotation);
  16779. - }
  16780. + @Override
  16781. + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
  16782. + int inverseNumRotation = (4 - numRotation) % 4;
  16783. + int height = getOutputImageHeight(inputImageHeight, inputImageWidth);
  16784. + int width = getOutputImageWidth(inputImageHeight, inputImageWidth);
  16785. + return transformImpl(point, height, width, inverseNumRotation);
  16786. + }
  16787. - private static PointF transformImpl(PointF point, int height, int width, int numRotation) {
  16788. - if (numRotation == 0) {
  16789. - return point;
  16790. - } else if (numRotation == 1) {
  16791. - return new PointF(point.y, width - point.x);
  16792. - } else if (numRotation == 2) {
  16793. - return new PointF(width - point.x, height - point.y);
  16794. - } else { // numRotation == 3
  16795. - return new PointF(height - point.y, point.x);
  16796. + private static PointF transformImpl(PointF point, int height, int width, int numRotation) {
  16797. + if (numRotation == 0) {
  16798. + return point;
  16799. + } else if (numRotation == 1) {
  16800. + return new PointF(point.y, width - point.x);
  16801. + } else if (numRotation == 2) {
  16802. + return new PointF(width - point.x, height - point.y);
  16803. + } else { // numRotation == 3
  16804. + return new PointF(height - point.y, point.x);
  16805. + }
  16806. }
  16807. - }
  16808. }
  16809. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java
  16810. index 5d10ac890e57b..feb2b3b7b0762 100644
  16811. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java
  16812. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java
  16813. @@ -16,6 +16,7 @@ limitations under the License.
  16814. package org.tensorflow.lite.support.image.ops;
  16815. import android.graphics.PointF;
  16816. +
  16817. import org.checkerframework.checker.nullness.qual.NonNull;
  16818. import org.tensorflow.lite.support.common.TensorOperator;
  16819. import org.tensorflow.lite.support.common.internal.SupportPreconditions;
  16820. @@ -31,48 +32,47 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  16821. * @see org.tensorflow.lite.support.image.TensorImage
  16822. */
  16823. public class TensorOperatorWrapper implements ImageOperator {
  16824. + private final TensorOperator tensorOp;
  16825. - private final TensorOperator tensorOp;
  16826. -
  16827. - /**
  16828. - * Wraps a {@link TensorOperator} object as an {@link ImageOperator}, so that the {@link
  16829. - * TensorOperator} could handle {@link TensorImage} objects by handling its underlying {@link
  16830. - * org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
  16831. - *
  16832. - * <p>Requirement: The {@code op} should not change coordinate system when applied on an image.
  16833. - *
  16834. - * @param op The created operator.
  16835. - */
  16836. - public TensorOperatorWrapper(TensorOperator op) {
  16837. - tensorOp = op;
  16838. - }
  16839. + /**
  16840. + * Wraps a {@link TensorOperator} object as an {@link ImageOperator}, so that the {@link
  16841. + * TensorOperator} could handle {@link TensorImage} objects by handling its underlying {@link
  16842. + * org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
  16843. + *
  16844. + * <p>Requirement: The {@code op} should not change coordinate system when applied on an image.
  16845. + *
  16846. + * @param op The created operator.
  16847. + */
  16848. + public TensorOperatorWrapper(TensorOperator op) {
  16849. + tensorOp = op;
  16850. + }
  16851. - @Override
  16852. - @NonNull
  16853. - public TensorImage apply(@NonNull TensorImage image) {
  16854. - SupportPreconditions.checkNotNull(image, "Op cannot apply on null image.");
  16855. - TensorBuffer resBuffer = tensorOp.apply(image.getTensorBuffer());
  16856. - // Some ops may change the data type of the underlying TensorBuffer, such as CastOp. Therefore,
  16857. - // need to create a new TensorImage with the correct data type.
  16858. - // However the underlying ops should not touch the color type.
  16859. - ColorSpaceType colorSpaceType = image.getColorSpaceType();
  16860. - TensorImage resImage = new TensorImage(resBuffer.getDataType());
  16861. - resImage.load(resBuffer, colorSpaceType);
  16862. - return resImage;
  16863. - }
  16864. + @Override
  16865. + @NonNull
  16866. + public TensorImage apply(@NonNull TensorImage image) {
  16867. + SupportPreconditions.checkNotNull(image, "Op cannot apply on null image.");
  16868. + TensorBuffer resBuffer = tensorOp.apply(image.getTensorBuffer());
  16869. + // Some ops may change the data type of the underlying TensorBuffer, such as CastOp.
  16870. + // Therefore, need to create a new TensorImage with the correct data type. However the
  16871. + // underlying ops should not touch the color type.
  16872. + ColorSpaceType colorSpaceType = image.getColorSpaceType();
  16873. + TensorImage resImage = new TensorImage(resBuffer.getDataType());
  16874. + resImage.load(resBuffer, colorSpaceType);
  16875. + return resImage;
  16876. + }
  16877. - @Override
  16878. - public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
  16879. - return inputImageHeight;
  16880. - }
  16881. + @Override
  16882. + public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
  16883. + return inputImageHeight;
  16884. + }
  16885. - @Override
  16886. - public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
  16887. - return inputImageWidth;
  16888. - }
  16889. + @Override
  16890. + public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
  16891. + return inputImageWidth;
  16892. + }
  16893. - @Override
  16894. - public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
  16895. - return point;
  16896. - }
  16897. + @Override
  16898. + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
  16899. + return point;
  16900. + }
  16901. }
  16902. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java
  16903. index bd3c10b254ac5..1a6f905b1bffd 100644
  16904. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java
  16905. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java
  16906. @@ -23,6 +23,7 @@ import android.graphics.ColorFilter;
  16907. import android.graphics.ColorMatrixColorFilter;
  16908. import android.graphics.Paint;
  16909. import android.graphics.PointF;
  16910. +
  16911. import org.tensorflow.lite.support.image.ColorSpaceType;
  16912. import org.tensorflow.lite.support.image.ImageOperator;
  16913. import org.tensorflow.lite.support.image.TensorImage;
  16914. @@ -41,77 +42,73 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  16915. * https://docs.opencv.org/master/de/d25/imgproc_color_conversions.html#color_convert_rgb_gray
  16916. */
  16917. public class TransformToGrayscaleOp implements ImageOperator {
  16918. + // A matrix is created that will be applied later to canvas to generate grayscale image
  16919. + // The luminance of each pixel is calculated as the weighted sum of the 3 RGB values
  16920. + // Y = 0.299R + 0.587G + 0.114B
  16921. + private static final float[] BITMAP_RGBA_GRAYSCALE_TRANSFORMATION =
  16922. + new float[] {0.299F, 0.587F, 0.114F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F,
  16923. + 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F};
  16924. - // A matrix is created that will be applied later to canvas to generate grayscale image
  16925. - // The luminance of each pixel is calculated as the weighted sum of the 3 RGB values
  16926. - // Y = 0.299R + 0.587G + 0.114B
  16927. - private static final float[] BITMAP_RGBA_GRAYSCALE_TRANSFORMATION =
  16928. - new float[] {
  16929. - 0.299F, 0.587F, 0.114F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F,
  16930. - 0.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F
  16931. - };
  16932. -
  16933. - /** Creates a TransformToGrayscaleOp. */
  16934. - public TransformToGrayscaleOp() {}
  16935. + /** Creates a TransformToGrayscaleOp. */
  16936. + public TransformToGrayscaleOp() {}
  16937. - /**
  16938. - * Applies the transformation to grayscale and returns a {@link TensorImage}.
  16939. - *
  16940. - * <p>If the input image is already {@link
  16941. - * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}, this op will be a no-op.
  16942. - *
  16943. - * @throws IllegalArgumentException if the {@code image} is not {@link
  16944. - * org.tensorflow.lite.support.image.ColorSpaceType#RGB} or {@link
  16945. - * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}.
  16946. - */
  16947. - @Override
  16948. - public TensorImage apply(TensorImage image) {
  16949. - if (image.getColorSpaceType() == ColorSpaceType.GRAYSCALE) {
  16950. - return image;
  16951. - } else {
  16952. - checkArgument(
  16953. - image.getColorSpaceType() == ColorSpaceType.RGB,
  16954. - "Only RGB images are supported in TransformToGrayscaleOp, but not "
  16955. - + image.getColorSpaceType().name());
  16956. - }
  16957. - int h = image.getHeight();
  16958. - int w = image.getWidth();
  16959. - Bitmap bmpGrayscale = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
  16960. - Canvas canvas = new Canvas(bmpGrayscale);
  16961. - Paint paint = new Paint();
  16962. - ColorMatrixColorFilter colorMatrixFilter =
  16963. - new ColorMatrixColorFilter(BITMAP_RGBA_GRAYSCALE_TRANSFORMATION);
  16964. - paint.setColorFilter((ColorFilter) colorMatrixFilter);
  16965. - canvas.drawBitmap(image.getBitmap(), 0.0F, 0.0F, paint);
  16966. + /**
  16967. + * Applies the transformation to grayscale and returns a {@link TensorImage}.
  16968. + *
  16969. + * <p>If the input image is already {@link
  16970. + * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}, this op will be a no-op.
  16971. + *
  16972. + * @throws IllegalArgumentException if the {@code image} is not {@link
  16973. + * org.tensorflow.lite.support.image.ColorSpaceType#RGB} or {@link
  16974. + * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}.
  16975. + */
  16976. + @Override
  16977. + public TensorImage apply(TensorImage image) {
  16978. + if (image.getColorSpaceType() == ColorSpaceType.GRAYSCALE) {
  16979. + return image;
  16980. + } else {
  16981. + checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB,
  16982. + "Only RGB images are supported in TransformToGrayscaleOp, but not "
  16983. + + image.getColorSpaceType().name());
  16984. + }
  16985. + int h = image.getHeight();
  16986. + int w = image.getWidth();
  16987. + Bitmap bmpGrayscale = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
  16988. + Canvas canvas = new Canvas(bmpGrayscale);
  16989. + Paint paint = new Paint();
  16990. + ColorMatrixColorFilter colorMatrixFilter =
  16991. + new ColorMatrixColorFilter(BITMAP_RGBA_GRAYSCALE_TRANSFORMATION);
  16992. + paint.setColorFilter((ColorFilter) colorMatrixFilter);
  16993. + canvas.drawBitmap(image.getBitmap(), 0.0F, 0.0F, paint);
  16994. - // Get the pixels from the generated grayscale image
  16995. - int[] intValues = new int[w * h];
  16996. - bmpGrayscale.getPixels(intValues, 0, w, 0, 0, w, h);
  16997. - // Shape with one channel
  16998. - int[] shape = new int[] {1, h, w, 1};
  16999. + // Get the pixels from the generated grayscale image
  17000. + int[] intValues = new int[w * h];
  17001. + bmpGrayscale.getPixels(intValues, 0, w, 0, 0, w, h);
  17002. + // Shape with one channel
  17003. + int[] shape = new int[] {1, h, w, 1};
  17004. - // Get R channel from ARGB color
  17005. - for (int i = 0; i < intValues.length; i++) {
  17006. - intValues[i] = ((intValues[i] >> 16) & 0xff);
  17007. + // Get R channel from ARGB color
  17008. + for (int i = 0; i < intValues.length; i++) {
  17009. + intValues[i] = ((intValues[i] >> 16) & 0xff);
  17010. + }
  17011. + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, image.getDataType());
  17012. + buffer.loadArray(intValues, shape);
  17013. + image.load(buffer, ColorSpaceType.GRAYSCALE);
  17014. + return image;
  17015. }
  17016. - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, image.getDataType());
  17017. - buffer.loadArray(intValues, shape);
  17018. - image.load(buffer, ColorSpaceType.GRAYSCALE);
  17019. - return image;
  17020. - }
  17021. - @Override
  17022. - public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
  17023. - return inputImageHeight;
  17024. - }
  17025. + @Override
  17026. + public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
  17027. + return inputImageHeight;
  17028. + }
  17029. - @Override
  17030. - public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
  17031. - return inputImageWidth;
  17032. - }
  17033. + @Override
  17034. + public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
  17035. + return inputImageWidth;
  17036. + }
  17037. - @Override
  17038. - public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
  17039. - return point;
  17040. - }
  17041. + @Override
  17042. + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
  17043. + return point;
  17044. + }
  17045. }
  17046. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java
  17047. index 8135ddcc28619..af56b70a77cf3 100644
  17048. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java
  17049. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java
  17050. @@ -15,9 +15,10 @@ limitations under the License.
  17051. package org.tensorflow.lite.support.label;
  17052. -import java.util.Objects;
  17053. import org.tensorflow.lite.annotations.UsedByReflection;
  17054. +import java.util.Objects;
  17055. +
  17056. /**
  17057. * Category is a util class, contains a label, its display name, a float value as score, and the
  17058. * index of the label in the corresponding label file. Typically it's used as result of
  17059. @@ -25,102 +26,97 @@ import org.tensorflow.lite.annotations.UsedByReflection;
  17060. */
  17061. @UsedByReflection("TFLiteSupport/Task")
  17062. public final class Category {
  17063. - private static final int DEFAULT_INDEX = -1;
  17064. - private static final float TOLERANCE = 1e-6f;
  17065. - private final int index;
  17066. - private final String label;
  17067. - private final String displayName;
  17068. - private final float score;
  17069. -
  17070. - /**
  17071. - * Constructs a {@link Category} object.
  17072. - *
  17073. - * @param label the label of this category object
  17074. - * @param displayName the display name of the label, which may be translated for different
  17075. - * locales. For exmaple, a label, "apple", may be translated into Spanish for display purpose,
  17076. - * so that the displayName is "manzana".
  17077. - * @param score the probability score of this label category
  17078. - * @param index the index of the label in the corresponding label file
  17079. - */
  17080. - @UsedByReflection("TFLiteSupport/Task")
  17081. - public static Category create(String label, String displayName, float score, int index) {
  17082. - return new Category(label, displayName, score, index);
  17083. - }
  17084. -
  17085. - /** Constructs a {@link Category} object with the default index (-1). */
  17086. - @UsedByReflection("TFLiteSupport/Task")
  17087. - public static Category create(String label, String displayName, float score) {
  17088. - return new Category(label, displayName, score, DEFAULT_INDEX);
  17089. - }
  17090. -
  17091. - /** Constructs a {@link Category} object with an empty displayName and the default index (-1). */
  17092. - @UsedByReflection("TFLiteSupport/Task")
  17093. - public Category(String label, float score) {
  17094. - this(label, /*displayName=*/ "", score, DEFAULT_INDEX);
  17095. - }
  17096. -
  17097. - private Category(String label, String displayName, float score, int index) {
  17098. - this.label = label;
  17099. - this.displayName = displayName;
  17100. - this.score = score;
  17101. - this.index = index;
  17102. - }
  17103. -
  17104. - /** Gets the reference of category's label. */
  17105. - public String getLabel() {
  17106. - return label;
  17107. - }
  17108. -
  17109. - /**
  17110. - * Gets the reference of category's displayName, a name in locale of the label.
  17111. - *
  17112. - * <p>The display name can be an empty string if this {@link Category} object is constructed
  17113. - * without displayName, such as when using {@link #Category(String label, float score)}.
  17114. - */
  17115. - public String getDisplayName() {
  17116. - return displayName;
  17117. - }
  17118. -
  17119. - /** Gets the score of the category. */
  17120. - public float getScore() {
  17121. - return score;
  17122. - }
  17123. -
  17124. - /**
  17125. - * Gets the index of the category. The index value might be -1, which means it has not been set up
  17126. - * properly and is invalid.
  17127. - */
  17128. - public int getIndex() {
  17129. - return index;
  17130. - }
  17131. -
  17132. - @Override
  17133. - public boolean equals(Object o) {
  17134. - if (o instanceof Category) {
  17135. - Category other = (Category) o;
  17136. - return (other.getLabel().equals(this.label)
  17137. - && other.getDisplayName().equals(this.displayName)
  17138. - && Math.abs(other.getScore() - this.score) < TOLERANCE
  17139. - && other.getIndex() == this.index);
  17140. + private static final int DEFAULT_INDEX = -1;
  17141. + private static final float TOLERANCE = 1e-6f;
  17142. + private final int index;
  17143. + private final String label;
  17144. + private final String displayName;
  17145. + private final float score;
  17146. +
  17147. + /**
  17148. + * Constructs a {@link Category} object.
  17149. + *
  17150. + * @param label the label of this category object
  17151. + * @param displayName the display name of the label, which may be translated for different
  17152. + * locales. For exmaple, a label, "apple", may be translated into Spanish for display
  17153. + * purpose, so that the displayName is "manzana".
  17154. + * @param score the probability score of this label category
  17155. + * @param index the index of the label in the corresponding label file
  17156. + */
  17157. + @UsedByReflection("TFLiteSupport/Task")
  17158. + public static Category create(String label, String displayName, float score, int index) {
  17159. + return new Category(label, displayName, score, index);
  17160. + }
  17161. +
  17162. + /** Constructs a {@link Category} object with the default index (-1). */
  17163. + @UsedByReflection("TFLiteSupport/Task")
  17164. + public static Category create(String label, String displayName, float score) {
  17165. + return new Category(label, displayName, score, DEFAULT_INDEX);
  17166. + }
  17167. +
  17168. + /**
  17169. + * Constructs a {@link Category} object with an empty displayName and the default index (-1).
  17170. + */
  17171. + @UsedByReflection("TFLiteSupport/Task")
  17172. + public Category(String label, float score) {
  17173. + this(label, /*displayName=*/"", score, DEFAULT_INDEX);
  17174. + }
  17175. +
  17176. + private Category(String label, String displayName, float score, int index) {
  17177. + this.label = label;
  17178. + this.displayName = displayName;
  17179. + this.score = score;
  17180. + this.index = index;
  17181. + }
  17182. +
  17183. + /** Gets the reference of category's label. */
  17184. + public String getLabel() {
  17185. + return label;
  17186. + }
  17187. +
  17188. + /**
  17189. + * Gets the reference of category's displayName, a name in locale of the label.
  17190. + *
  17191. + * <p>The display name can be an empty string if this {@link Category} object is constructed
  17192. + * without displayName, such as when using {@link #Category(String label, float score)}.
  17193. + */
  17194. + public String getDisplayName() {
  17195. + return displayName;
  17196. + }
  17197. +
  17198. + /** Gets the score of the category. */
  17199. + public float getScore() {
  17200. + return score;
  17201. + }
  17202. +
  17203. + /**
  17204. + * Gets the index of the category. The index value might be -1, which means it has not been set
  17205. + * up properly and is invalid.
  17206. + */
  17207. + public int getIndex() {
  17208. + return index;
  17209. + }
  17210. +
  17211. + @Override
  17212. + public boolean equals(Object o) {
  17213. + if (o instanceof Category) {
  17214. + Category other = (Category) o;
  17215. + return (other.getLabel().equals(this.label)
  17216. + && other.getDisplayName().equals(this.displayName)
  17217. + && Math.abs(other.getScore() - this.score) < TOLERANCE
  17218. + && other.getIndex() == this.index);
  17219. + }
  17220. + return false;
  17221. + }
  17222. +
  17223. + @Override
  17224. + public int hashCode() {
  17225. + return Objects.hash(label, displayName, score, index);
  17226. + }
  17227. +
  17228. + @Override
  17229. + public String toString() {
  17230. + return "<Category \"" + label + "\" (displayName=" + displayName + " score=" + score
  17231. + + " index=" + index + ")>";
  17232. }
  17233. - return false;
  17234. - }
  17235. -
  17236. - @Override
  17237. - public int hashCode() {
  17238. - return Objects.hash(label, displayName, score, index);
  17239. - }
  17240. -
  17241. - @Override
  17242. - public String toString() {
  17243. - return "<Category \""
  17244. - + label
  17245. - + "\" (displayName="
  17246. - + displayName
  17247. - + " score="
  17248. - + score
  17249. - + " index="
  17250. - + index
  17251. - + ")>";
  17252. - }
  17253. }
  17254. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java
  17255. index af21d74e25f5d..56ee89f091e03 100644
  17256. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java
  17257. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java
  17258. @@ -16,49 +16,52 @@ limitations under the License.
  17259. package org.tensorflow.lite.support.label;
  17260. import android.util.Log;
  17261. -import java.util.ArrayList;
  17262. -import java.util.Arrays;
  17263. -import java.util.List;
  17264. +
  17265. import org.checkerframework.checker.nullness.qual.NonNull;
  17266. import org.tensorflow.lite.support.common.internal.SupportPreconditions;
  17267. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  17268. +import java.util.ArrayList;
  17269. +import java.util.Arrays;
  17270. +import java.util.List;
  17271. +
  17272. /** Label operation utils. */
  17273. public class LabelUtil {
  17274. - /**
  17275. - * Maps an int value tensor to a list of string labels. It takes an array of strings as the
  17276. - * dictionary. Example: if the given tensor is [3, 1, 0], and given labels is ["background",
  17277. - * "apple", "banana", "cherry", "date"], the result will be ["date", "banana", "apple"].
  17278. - *
  17279. - * @param tensorBuffer A tensor with index values. The values should be non-negative integers, and
  17280. - * each value {@code x} will be converted to {@code labels[x + offset]}. If the tensor is
  17281. - * given as a float {@link TensorBuffer}, values will be cast to integers. All values that are
  17282. - * out of bound will map to empty string.
  17283. - * @param labels A list of strings, used as a dictionary to look up. The index of the array
  17284. - * element will be used as the key. To get better performance, use an object that implements
  17285. - * RandomAccess, such as {@link ArrayList}.
  17286. - * @param offset The offset value when look up int values in the {@code labels}.
  17287. - * @return the mapped strings. The length of the list is {@link TensorBuffer#getFlatSize}.
  17288. - * @throws IllegalArgumentException if {@code tensorBuffer} or {@code labels} is null.
  17289. - */
  17290. - public static List<String> mapValueToLabels(
  17291. - @NonNull TensorBuffer tensorBuffer, @NonNull List<String> labels, int offset) {
  17292. - SupportPreconditions.checkNotNull(tensorBuffer, "Given tensor should not be null");
  17293. - SupportPreconditions.checkNotNull(labels, "Given labels should not be null");
  17294. - int[] values = tensorBuffer.getIntArray();
  17295. - Log.d("values", Arrays.toString(values));
  17296. - List<String> result = new ArrayList<>();
  17297. - for (int v : values) {
  17298. - int index = v + offset;
  17299. - if (index < 0 || index >= labels.size()) {
  17300. - result.add("");
  17301. - } else {
  17302. - result.add(labels.get(index));
  17303. - }
  17304. + /**
  17305. + * Maps an int value tensor to a list of string labels. It takes an array of strings as the
  17306. + * dictionary. Example: if the given tensor is [3, 1, 0], and given labels is ["background",
  17307. + * "apple", "banana", "cherry", "date"], the result will be ["date", "banana", "apple"].
  17308. + *
  17309. + * @param tensorBuffer A tensor with index values. The values should be non-negative integers,
  17310. + * and
  17311. + * each value {@code x} will be converted to {@code labels[x + offset]}. If the tensor is
  17312. + * given as a float {@link TensorBuffer}, values will be cast to integers. All values that
  17313. + * are out of bound will map to empty string.
  17314. + * @param labels A list of strings, used as a dictionary to look up. The index of the array
  17315. + * element will be used as the key. To get better performance, use an object that implements
  17316. + * RandomAccess, such as {@link ArrayList}.
  17317. + * @param offset The offset value when look up int values in the {@code labels}.
  17318. + * @return the mapped strings. The length of the list is {@link TensorBuffer#getFlatSize}.
  17319. + * @throws IllegalArgumentException if {@code tensorBuffer} or {@code labels} is null.
  17320. + */
  17321. + public static List<String> mapValueToLabels(
  17322. + @NonNull TensorBuffer tensorBuffer, @NonNull List<String> labels, int offset) {
  17323. + SupportPreconditions.checkNotNull(tensorBuffer, "Given tensor should not be null");
  17324. + SupportPreconditions.checkNotNull(labels, "Given labels should not be null");
  17325. + int[] values = tensorBuffer.getIntArray();
  17326. + Log.d("values", Arrays.toString(values));
  17327. + List<String> result = new ArrayList<>();
  17328. + for (int v : values) {
  17329. + int index = v + offset;
  17330. + if (index < 0 || index >= labels.size()) {
  17331. + result.add("");
  17332. + } else {
  17333. + result.add(labels.get(index));
  17334. + }
  17335. + }
  17336. + return result;
  17337. }
  17338. - return result;
  17339. - }
  17340. - // Private constructor to prevent initialization.
  17341. - private LabelUtil() {}
  17342. + // Private constructor to prevent initialization.
  17343. + private LabelUtil() {}
  17344. }
  17345. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java
  17346. index bdab7cf464c1b..edd683cd08126 100644
  17347. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java
  17348. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java
  17349. @@ -16,16 +16,18 @@ limitations under the License.
  17350. package org.tensorflow.lite.support.label;
  17351. import android.content.Context;
  17352. +
  17353. +import org.checkerframework.checker.nullness.qual.NonNull;
  17354. +import org.tensorflow.lite.DataType;
  17355. +import org.tensorflow.lite.support.common.internal.SupportPreconditions;
  17356. +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  17357. +
  17358. import java.nio.ByteBuffer;
  17359. import java.util.ArrayList;
  17360. import java.util.Arrays;
  17361. import java.util.LinkedHashMap;
  17362. import java.util.List;
  17363. import java.util.Map;
  17364. -import org.checkerframework.checker.nullness.qual.NonNull;
  17365. -import org.tensorflow.lite.DataType;
  17366. -import org.tensorflow.lite.support.common.internal.SupportPreconditions;
  17367. -import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  17368. /**
  17369. * TensorLabel is an util wrapper for TensorBuffers with meaningful labels on an axis.
  17370. @@ -56,169 +58,170 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  17371. * a label file (plain text file whose each line is a label) in assets simply.
  17372. */
  17373. public class TensorLabel {
  17374. - private final Map<Integer, List<String>> axisLabels;
  17375. - private final TensorBuffer tensorBuffer;
  17376. - private final int[] shape;
  17377. -
  17378. - /**
  17379. - * Creates a TensorLabel object which is able to label on the axes of multi-dimensional tensors.
  17380. - *
  17381. - * @param axisLabels A map, whose key is axis id (starting from 0) and value is corresponding
  17382. - * labels. Note: The size of labels should be same with the size of the tensor on that axis.
  17383. - * @param tensorBuffer The TensorBuffer to be labeled.
  17384. - * @throws NullPointerException if {@code axisLabels} or {@code tensorBuffer} is null, or any
  17385. - * value in {@code axisLabels} is null.
  17386. - * @throws IllegalArgumentException if any key in {@code axisLabels} is out of range (compared to
  17387. - * the shape of {@code tensorBuffer}, or any value (labels) has different size with the {@code
  17388. - * tensorBuffer} on the given dimension.
  17389. - */
  17390. - public TensorLabel(
  17391. - @NonNull Map<Integer, List<String>> axisLabels, @NonNull TensorBuffer tensorBuffer) {
  17392. - SupportPreconditions.checkNotNull(axisLabels, "Axis labels cannot be null.");
  17393. - SupportPreconditions.checkNotNull(tensorBuffer, "Tensor Buffer cannot be null.");
  17394. - this.axisLabels = axisLabels;
  17395. - this.tensorBuffer = tensorBuffer;
  17396. - this.shape = tensorBuffer.getShape();
  17397. - for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) {
  17398. - int axis = entry.getKey();
  17399. - SupportPreconditions.checkArgument(
  17400. - axis >= 0 && axis < shape.length, "Invalid axis id: " + axis);
  17401. - SupportPreconditions.checkNotNull(entry.getValue(), "Label list is null on axis " + axis);
  17402. - SupportPreconditions.checkArgument(
  17403. - shape[axis] == entry.getValue().size(),
  17404. - "Label number " + entry.getValue().size() + " mismatch the shape on axis " + axis);
  17405. + private final Map<Integer, List<String>> axisLabels;
  17406. + private final TensorBuffer tensorBuffer;
  17407. + private final int[] shape;
  17408. +
  17409. + /**
  17410. + * Creates a TensorLabel object which is able to label on the axes of multi-dimensional tensors.
  17411. + *
  17412. + * @param axisLabels A map, whose key is axis id (starting from 0) and value is corresponding
  17413. + * labels. Note: The size of labels should be same with the size of the tensor on that axis.
  17414. + * @param tensorBuffer The TensorBuffer to be labeled.
  17415. + * @throws NullPointerException if {@code axisLabels} or {@code tensorBuffer} is null, or any
  17416. + * value in {@code axisLabels} is null.
  17417. + * @throws IllegalArgumentException if any key in {@code axisLabels} is out of range (compared
  17418. + * to
  17419. + * the shape of {@code tensorBuffer}, or any value (labels) has different size with the
  17420. + * {@code tensorBuffer} on the given dimension.
  17421. + */
  17422. + public TensorLabel(
  17423. + @NonNull Map<Integer, List<String>> axisLabels, @NonNull TensorBuffer tensorBuffer) {
  17424. + SupportPreconditions.checkNotNull(axisLabels, "Axis labels cannot be null.");
  17425. + SupportPreconditions.checkNotNull(tensorBuffer, "Tensor Buffer cannot be null.");
  17426. + this.axisLabels = axisLabels;
  17427. + this.tensorBuffer = tensorBuffer;
  17428. + this.shape = tensorBuffer.getShape();
  17429. + for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) {
  17430. + int axis = entry.getKey();
  17431. + SupportPreconditions.checkArgument(
  17432. + axis >= 0 && axis < shape.length, "Invalid axis id: " + axis);
  17433. + SupportPreconditions.checkNotNull(
  17434. + entry.getValue(), "Label list is null on axis " + axis);
  17435. + SupportPreconditions.checkArgument(shape[axis] == entry.getValue().size(),
  17436. + "Label number " + entry.getValue().size() + " mismatch the shape on axis "
  17437. + + axis);
  17438. + }
  17439. }
  17440. - }
  17441. -
  17442. - /**
  17443. - * Creates a TensorLabel object which is able to label on one axis of multi-dimensional tensors.
  17444. - *
  17445. - * <p>Note: The labels are applied on the first axis whose size is larger than 1. For example, if
  17446. - * the shape of the tensor is [1, 10, 3], the labels will be applied on axis 1 (id starting from
  17447. - * 0), and size of {@code axisLabels} should be 10 as well.
  17448. - *
  17449. - * @param axisLabels A list of labels, whose size should be same with the size of the tensor on
  17450. - * the to-be-labeled axis.
  17451. - * @param tensorBuffer The TensorBuffer to be labeled.
  17452. - */
  17453. - public TensorLabel(@NonNull List<String> axisLabels, @NonNull TensorBuffer tensorBuffer) {
  17454. - this(makeMap(getFirstAxisWithSizeGreaterThanOne(tensorBuffer), axisLabels), tensorBuffer);
  17455. - }
  17456. -
  17457. - /**
  17458. - * Gets the map with a pair of the label and the corresponding TensorBuffer. Only allow the
  17459. - * mapping on the first axis with size greater than 1 currently.
  17460. - */
  17461. - @NonNull
  17462. - public Map<String, TensorBuffer> getMapWithTensorBuffer() {
  17463. - int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
  17464. -
  17465. - Map<String, TensorBuffer> labelToTensorMap = new LinkedHashMap<>();
  17466. - SupportPreconditions.checkArgument(
  17467. - axisLabels.containsKey(labeledAxis),
  17468. - "get a <String, TensorBuffer> map requires the labels are set on the first non-1 axis.");
  17469. - List<String> labels = axisLabels.get(labeledAxis);
  17470. -
  17471. - DataType dataType = tensorBuffer.getDataType();
  17472. - int typeSize = tensorBuffer.getTypeSize();
  17473. - int flatSize = tensorBuffer.getFlatSize();
  17474. -
  17475. - // Gets the underlying bytes that could be used to generate the sub-array later.
  17476. - ByteBuffer byteBuffer = tensorBuffer.getBuffer();
  17477. - byteBuffer.rewind();
  17478. -
  17479. - // Note: computation below is only correct when labeledAxis is the first axis with size greater
  17480. - // than 1.
  17481. - int subArrayLength = flatSize / shape[labeledAxis] * typeSize;
  17482. - int i = 0;
  17483. - SupportPreconditions.checkNotNull(labels, "Label list should never be null");
  17484. - for (String label : labels) {
  17485. - // Gets the corresponding TensorBuffer.
  17486. - byteBuffer.position(i * subArrayLength);
  17487. - ByteBuffer subBuffer = byteBuffer.slice();
  17488. - // ByteBuffer.slice doesn't keep order. Modify it to align with the original one.
  17489. - subBuffer.order(byteBuffer.order()).limit(subArrayLength);
  17490. - TensorBuffer labelBuffer = TensorBuffer.createDynamic(dataType);
  17491. - labelBuffer.loadBuffer(subBuffer, Arrays.copyOfRange(shape, labeledAxis + 1, shape.length));
  17492. - labelToTensorMap.put(label, labelBuffer);
  17493. - i += 1;
  17494. +
  17495. + /**
  17496. + * Creates a TensorLabel object which is able to label on one axis of multi-dimensional tensors.
  17497. + *
  17498. + * <p>Note: The labels are applied on the first axis whose size is larger than 1. For example,
  17499. + * if the shape of the tensor is [1, 10, 3], the labels will be applied on axis 1 (id starting
  17500. + * from 0), and size of {@code axisLabels} should be 10 as well.
  17501. + *
  17502. + * @param axisLabels A list of labels, whose size should be same with the size of the tensor on
  17503. + * the to-be-labeled axis.
  17504. + * @param tensorBuffer The TensorBuffer to be labeled.
  17505. + */
  17506. + public TensorLabel(@NonNull List<String> axisLabels, @NonNull TensorBuffer tensorBuffer) {
  17507. + this(makeMap(getFirstAxisWithSizeGreaterThanOne(tensorBuffer), axisLabels), tensorBuffer);
  17508. }
  17509. - return labelToTensorMap;
  17510. - }
  17511. -
  17512. - /**
  17513. - * Gets a map that maps label to float. Only allow the mapping on the first axis with size greater
  17514. - * than 1, and the axis should be effectively the last axis (which means every sub tensor
  17515. - * specified by this axis should have a flat size of 1).
  17516. - *
  17517. - * <p>{@link TensorLabel#getCategoryList()} is an alternative API to get the result.
  17518. - *
  17519. - * @throws IllegalStateException if size of a sub tensor on each label is not 1.
  17520. - */
  17521. - @NonNull
  17522. - public Map<String, Float> getMapWithFloatValue() {
  17523. - int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
  17524. - SupportPreconditions.checkState(
  17525. - labeledAxis == shape.length - 1,
  17526. - "get a <String, Scalar> map is only valid when the only labeled axis is the last one.");
  17527. - List<String> labels = axisLabels.get(labeledAxis);
  17528. - float[] data = tensorBuffer.getFloatArray();
  17529. - SupportPreconditions.checkState(labels.size() == data.length);
  17530. - Map<String, Float> result = new LinkedHashMap<>();
  17531. - int i = 0;
  17532. - for (String label : labels) {
  17533. - result.put(label, data[i]);
  17534. - i += 1;
  17535. +
  17536. + /**
  17537. + * Gets the map with a pair of the label and the corresponding TensorBuffer. Only allow the
  17538. + * mapping on the first axis with size greater than 1 currently.
  17539. + */
  17540. + @NonNull
  17541. + public Map<String, TensorBuffer> getMapWithTensorBuffer() {
  17542. + int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
  17543. +
  17544. + Map<String, TensorBuffer> labelToTensorMap = new LinkedHashMap<>();
  17545. + SupportPreconditions.checkArgument(axisLabels.containsKey(labeledAxis),
  17546. + "get a <String, TensorBuffer> map requires the labels are set on the first non-1 axis.");
  17547. + List<String> labels = axisLabels.get(labeledAxis);
  17548. +
  17549. + DataType dataType = tensorBuffer.getDataType();
  17550. + int typeSize = tensorBuffer.getTypeSize();
  17551. + int flatSize = tensorBuffer.getFlatSize();
  17552. +
  17553. + // Gets the underlying bytes that could be used to generate the sub-array later.
  17554. + ByteBuffer byteBuffer = tensorBuffer.getBuffer();
  17555. + byteBuffer.rewind();
  17556. +
  17557. + // Note: computation below is only correct when labeledAxis is the first axis with size
  17558. + // greater than 1.
  17559. + int subArrayLength = flatSize / shape[labeledAxis] * typeSize;
  17560. + int i = 0;
  17561. + SupportPreconditions.checkNotNull(labels, "Label list should never be null");
  17562. + for (String label : labels) {
  17563. + // Gets the corresponding TensorBuffer.
  17564. + byteBuffer.position(i * subArrayLength);
  17565. + ByteBuffer subBuffer = byteBuffer.slice();
  17566. + // ByteBuffer.slice doesn't keep order. Modify it to align with the original one.
  17567. + subBuffer.order(byteBuffer.order()).limit(subArrayLength);
  17568. + TensorBuffer labelBuffer = TensorBuffer.createDynamic(dataType);
  17569. + labelBuffer.loadBuffer(
  17570. + subBuffer, Arrays.copyOfRange(shape, labeledAxis + 1, shape.length));
  17571. + labelToTensorMap.put(label, labelBuffer);
  17572. + i += 1;
  17573. + }
  17574. + return labelToTensorMap;
  17575. }
  17576. - return result;
  17577. - }
  17578. -
  17579. - /**
  17580. - * Gets a list of {@link Category} from the {@link TensorLabel} object.
  17581. - *
  17582. - * <p>The axis of label should be effectively the last axis (which means every sub tensor
  17583. - * specified by this axis should have a flat size of 1), so that each labelled sub tensor could be
  17584. - * converted into a float value score. Example: A {@link TensorLabel} with shape {@code {2, 5, 3}}
  17585. - * and axis 2 is valid. If axis is 1 or 0, it cannot be converted into a {@link Category}.
  17586. - *
  17587. - * <p>{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as
  17588. - * the result.
  17589. - *
  17590. - * @throws IllegalStateException if size of a sub tensor on each label is not 1.
  17591. - */
  17592. - @NonNull
  17593. - public List<Category> getCategoryList() {
  17594. - int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
  17595. - SupportPreconditions.checkState(
  17596. - labeledAxis == shape.length - 1,
  17597. - "get a Category list is only valid when the only labeled axis is the last one.");
  17598. - List<String> labels = axisLabels.get(labeledAxis);
  17599. - float[] data = tensorBuffer.getFloatArray();
  17600. - SupportPreconditions.checkState(labels.size() == data.length);
  17601. - List<Category> result = new ArrayList<>();
  17602. - int i = 0;
  17603. - for (String label : labels) {
  17604. - result.add(new Category(label, data[i]));
  17605. - i += 1;
  17606. +
  17607. + /**
  17608. + * Gets a map that maps label to float. Only allow the mapping on the first axis with size
  17609. + * greater than 1, and the axis should be effectively the last axis (which means every sub
  17610. + * tensor specified by this axis should have a flat size of 1).
  17611. + *
  17612. + * <p>{@link TensorLabel#getCategoryList()} is an alternative API to get the result.
  17613. + *
  17614. + * @throws IllegalStateException if size of a sub tensor on each label is not 1.
  17615. + */
  17616. + @NonNull
  17617. + public Map<String, Float> getMapWithFloatValue() {
  17618. + int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
  17619. + SupportPreconditions.checkState(labeledAxis == shape.length - 1,
  17620. + "get a <String, Scalar> map is only valid when the only labeled axis is the last one.");
  17621. + List<String> labels = axisLabels.get(labeledAxis);
  17622. + float[] data = tensorBuffer.getFloatArray();
  17623. + SupportPreconditions.checkState(labels.size() == data.length);
  17624. + Map<String, Float> result = new LinkedHashMap<>();
  17625. + int i = 0;
  17626. + for (String label : labels) {
  17627. + result.put(label, data[i]);
  17628. + i += 1;
  17629. + }
  17630. + return result;
  17631. }
  17632. - return result;
  17633. - }
  17634. -
  17635. - private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) {
  17636. - int[] shape = tensorBuffer.getShape();
  17637. - for (int i = 0; i < shape.length; i++) {
  17638. - if (shape[i] > 1) {
  17639. - return i;
  17640. - }
  17641. +
  17642. + /**
  17643. + * Gets a list of {@link Category} from the {@link TensorLabel} object.
  17644. + *
  17645. + * <p>The axis of label should be effectively the last axis (which means every sub tensor
  17646. + * specified by this axis should have a flat size of 1), so that each labelled sub tensor could
  17647. + * be converted into a float value score. Example: A {@link TensorLabel} with shape {@code {2,
  17648. + * 5, 3}} and axis 2 is valid. If axis is 1 or 0, it cannot be converted into a {@link
  17649. + * Category}.
  17650. + *
  17651. + * <p>{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as
  17652. + * the result.
  17653. + *
  17654. + * @throws IllegalStateException if size of a sub tensor on each label is not 1.
  17655. + */
  17656. + @NonNull
  17657. + public List<Category> getCategoryList() {
  17658. + int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
  17659. + SupportPreconditions.checkState(labeledAxis == shape.length - 1,
  17660. + "get a Category list is only valid when the only labeled axis is the last one.");
  17661. + List<String> labels = axisLabels.get(labeledAxis);
  17662. + float[] data = tensorBuffer.getFloatArray();
  17663. + SupportPreconditions.checkState(labels.size() == data.length);
  17664. + List<Category> result = new ArrayList<>();
  17665. + int i = 0;
  17666. + for (String label : labels) {
  17667. + result.add(new Category(label, data[i]));
  17668. + i += 1;
  17669. + }
  17670. + return result;
  17671. + }
  17672. +
  17673. + private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) {
  17674. + int[] shape = tensorBuffer.getShape();
  17675. + for (int i = 0; i < shape.length; i++) {
  17676. + if (shape[i] > 1) {
  17677. + return i;
  17678. + }
  17679. + }
  17680. + throw new IllegalArgumentException(
  17681. + "Cannot find an axis to label. A valid axis to label should have size larger than 1.");
  17682. + }
  17683. +
  17684. + // Helper function to wrap the List<String> to a one-entry map.
  17685. + private static Map<Integer, List<String>> makeMap(int axis, List<String> labels) {
  17686. + Map<Integer, List<String>> map = new LinkedHashMap<>();
  17687. + map.put(axis, labels);
  17688. + return map;
  17689. }
  17690. - throw new IllegalArgumentException(
  17691. - "Cannot find an axis to label. A valid axis to label should have size larger than 1.");
  17692. - }
  17693. -
  17694. - // Helper function to wrap the List<String> to a one-entry map.
  17695. - private static Map<Integer, List<String>> makeMap(int axis, List<String> labels) {
  17696. - Map<Integer, List<String>> map = new LinkedHashMap<>();
  17697. - map.put(axis, labels);
  17698. - return map;
  17699. - }
  17700. }
  17701. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java
  17702. index ed47f65a726a6..e44edc64f4969 100644
  17703. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java
  17704. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java
  17705. @@ -16,16 +16,18 @@ limitations under the License.
  17706. package org.tensorflow.lite.support.label.ops;
  17707. import android.content.Context;
  17708. -import java.io.IOException;
  17709. -import java.util.HashMap;
  17710. -import java.util.List;
  17711. -import java.util.Map;
  17712. +
  17713. import org.checkerframework.checker.nullness.qual.NonNull;
  17714. import org.tensorflow.lite.support.common.FileUtil;
  17715. import org.tensorflow.lite.support.common.internal.SupportPreconditions;
  17716. import org.tensorflow.lite.support.label.TensorLabel;
  17717. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  17718. +import java.io.IOException;
  17719. +import java.util.HashMap;
  17720. +import java.util.List;
  17721. +import java.util.Map;
  17722. +
  17723. /**
  17724. * Labels TensorBuffer with axisLabels for outputs.
  17725. *
  17726. @@ -33,42 +35,42 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  17727. * a pair of the label name and the corresponding TensorBuffer value.
  17728. */
  17729. public class LabelAxisOp {
  17730. - // Axis and its corresponding label names.
  17731. - private final Map<Integer, List<String>> axisLabels;
  17732. -
  17733. - protected LabelAxisOp(Builder builder) {
  17734. - axisLabels = builder.axisLabels;
  17735. - }
  17736. -
  17737. - public TensorLabel apply(@NonNull TensorBuffer buffer) {
  17738. - SupportPreconditions.checkNotNull(buffer, "Tensor buffer cannot be null.");
  17739. - return new TensorLabel(axisLabels, buffer);
  17740. - }
  17741. -
  17742. - /** The inner builder class to build a LabelTensor Operator. */
  17743. - public static class Builder {
  17744. + // Axis and its corresponding label names.
  17745. private final Map<Integer, List<String>> axisLabels;
  17746. - protected Builder() {
  17747. - axisLabels = new HashMap<>();
  17748. + protected LabelAxisOp(Builder builder) {
  17749. + axisLabels = builder.axisLabels;
  17750. }
  17751. - public Builder addAxisLabel(@NonNull Context context, int axis, @NonNull String filePath)
  17752. - throws IOException {
  17753. - SupportPreconditions.checkNotNull(context, "Context cannot be null.");
  17754. - SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
  17755. - List<String> labels = FileUtil.loadLabels(context, filePath);
  17756. - axisLabels.put(axis, labels);
  17757. - return this;
  17758. + public TensorLabel apply(@NonNull TensorBuffer buffer) {
  17759. + SupportPreconditions.checkNotNull(buffer, "Tensor buffer cannot be null.");
  17760. + return new TensorLabel(axisLabels, buffer);
  17761. }
  17762. - public Builder addAxisLabel(int axis, @NonNull List<String> labels) {
  17763. - axisLabels.put(axis, labels);
  17764. - return this;
  17765. - }
  17766. + /** The inner builder class to build a LabelTensor Operator. */
  17767. + public static class Builder {
  17768. + private final Map<Integer, List<String>> axisLabels;
  17769. +
  17770. + protected Builder() {
  17771. + axisLabels = new HashMap<>();
  17772. + }
  17773. +
  17774. + public Builder addAxisLabel(@NonNull Context context, int axis, @NonNull String filePath)
  17775. + throws IOException {
  17776. + SupportPreconditions.checkNotNull(context, "Context cannot be null.");
  17777. + SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
  17778. + List<String> labels = FileUtil.loadLabels(context, filePath);
  17779. + axisLabels.put(axis, labels);
  17780. + return this;
  17781. + }
  17782. +
  17783. + public Builder addAxisLabel(int axis, @NonNull List<String> labels) {
  17784. + axisLabels.put(axis, labels);
  17785. + return this;
  17786. + }
  17787. - public LabelAxisOp build() {
  17788. - return new LabelAxisOp(this);
  17789. + public LabelAxisOp build() {
  17790. + return new LabelAxisOp(this);
  17791. + }
  17792. }
  17793. - }
  17794. }
  17795. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java
  17796. index 9cfcf923dedee..ada9b33fb0eea 100644
  17797. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java
  17798. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java
  17799. @@ -16,54 +16,55 @@ limitations under the License.
  17800. package org.tensorflow.lite.support.model;
  17801. import android.util.Log;
  17802. -import java.io.Closeable;
  17803. -import java.io.IOException;
  17804. +
  17805. import org.checkerframework.checker.nullness.qual.Nullable;
  17806. import org.tensorflow.lite.Delegate;
  17807. +import java.io.Closeable;
  17808. +import java.io.IOException;
  17809. +
  17810. /**
  17811. * Helper class to create and call necessary methods of {@code GpuDelegate} which is not a strict
  17812. * dependency.
  17813. */
  17814. class GpuDelegateProxy implements Delegate, Closeable {
  17815. + private static final String TAG = "GpuDelegateProxy";
  17816. - private static final String TAG = "GpuDelegateProxy";
  17817. -
  17818. - private final Delegate proxiedDelegate;
  17819. - private final Closeable proxiedCloseable;
  17820. + private final Delegate proxiedDelegate;
  17821. + private final Closeable proxiedCloseable;
  17822. - @Nullable
  17823. - public static GpuDelegateProxy maybeNewInstance() {
  17824. - try {
  17825. - Class<?> clazz = Class.forName("org.tensorflow.lite.gpu.GpuDelegate");
  17826. - Object instance = clazz.getDeclaredConstructor().newInstance();
  17827. - return new GpuDelegateProxy(instance);
  17828. - } catch (ReflectiveOperationException e) {
  17829. - Log.e(TAG, "Failed to create the GpuDelegate dynamically.", e);
  17830. - return null;
  17831. + @Nullable
  17832. + public static GpuDelegateProxy maybeNewInstance() {
  17833. + try {
  17834. + Class<?> clazz = Class.forName("org.tensorflow.lite.gpu.GpuDelegate");
  17835. + Object instance = clazz.getDeclaredConstructor().newInstance();
  17836. + return new GpuDelegateProxy(instance);
  17837. + } catch (ReflectiveOperationException e) {
  17838. + Log.e(TAG, "Failed to create the GpuDelegate dynamically.", e);
  17839. + return null;
  17840. + }
  17841. }
  17842. - }
  17843. - /** Calls {@code close()} method of the delegate. */
  17844. - @Override
  17845. - public void close() {
  17846. - try {
  17847. - proxiedCloseable.close();
  17848. - } catch (IOException e) {
  17849. - // Should not trigger, because GpuDelegate#close never throws. The catch is required because
  17850. - // of Closeable#close.
  17851. - Log.e(TAG, "Failed to close the GpuDelegate.", e);
  17852. + /** Calls {@code close()} method of the delegate. */
  17853. + @Override
  17854. + public void close() {
  17855. + try {
  17856. + proxiedCloseable.close();
  17857. + } catch (IOException e) {
  17858. + // Should not trigger, because GpuDelegate#close never throws. The catch is required
  17859. + // because of Closeable#close.
  17860. + Log.e(TAG, "Failed to close the GpuDelegate.", e);
  17861. + }
  17862. }
  17863. - }
  17864. - /** Calls {@code getNativeHandle()} method of the delegate. */
  17865. - @Override
  17866. - public long getNativeHandle() {
  17867. - return proxiedDelegate.getNativeHandle();
  17868. - }
  17869. + /** Calls {@code getNativeHandle()} method of the delegate. */
  17870. + @Override
  17871. + public long getNativeHandle() {
  17872. + return proxiedDelegate.getNativeHandle();
  17873. + }
  17874. - private GpuDelegateProxy(Object instance) {
  17875. - this.proxiedCloseable = (Closeable) instance;
  17876. - this.proxiedDelegate = (Delegate) instance;
  17877. - }
  17878. + private GpuDelegateProxy(Object instance) {
  17879. + this.proxiedCloseable = (Closeable) instance;
  17880. + this.proxiedDelegate = (Delegate) instance;
  17881. + }
  17882. }
  17883. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java
  17884. index 09b63f1b12beb..282f2b9aa599c 100644
  17885. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java
  17886. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java
  17887. @@ -16,9 +16,7 @@ limitations under the License.
  17888. package org.tensorflow.lite.support.model;
  17889. import android.content.Context;
  17890. -import java.io.IOException;
  17891. -import java.nio.MappedByteBuffer;
  17892. -import java.util.Map;
  17893. +
  17894. import org.checkerframework.checker.nullness.qual.NonNull;
  17895. import org.checkerframework.checker.nullness.qual.Nullable;
  17896. import org.tensorflow.lite.InterpreterApi;
  17897. @@ -27,6 +25,10 @@ import org.tensorflow.lite.Tensor;
  17898. import org.tensorflow.lite.support.common.FileUtil;
  17899. import org.tensorflow.lite.support.common.internal.SupportPreconditions;
  17900. +import java.io.IOException;
  17901. +import java.nio.MappedByteBuffer;
  17902. +import java.util.Map;
  17903. +
  17904. /**
  17905. * The wrapper class for a TFLite model and a TFLite interpreter.
  17906. *
  17907. @@ -34,263 +36,254 @@ import org.tensorflow.lite.support.common.internal.SupportPreconditions;
  17908. * interpreter instance to run it.
  17909. */
  17910. public class Model {
  17911. + /** The runtime device type used for executing classification. */
  17912. + public enum Device { CPU, NNAPI, GPU }
  17913. - /** The runtime device type used for executing classification. */
  17914. - public enum Device {
  17915. - CPU,
  17916. - NNAPI,
  17917. - GPU
  17918. - }
  17919. -
  17920. - /**
  17921. - * Options for running the model. Configurable parameters includes:
  17922. - *
  17923. - * <ul>
  17924. - * <li>{@code device} {@link Builder#setDevice(Device)} specifies the hardware to run the model.
  17925. - * The default value is {@link Device#CPU}.
  17926. - * <li>{@code numThreads} {@link Builder#setNumThreads(int)} specifies the number of threads
  17927. - * used by TFLite inference. It's only effective when device is set to {@link Device#CPU}
  17928. - * and default value is 1.
  17929. - * </ul>
  17930. - */
  17931. - public static class Options {
  17932. - private final Device device;
  17933. - private final int numThreads;
  17934. - private final TfLiteRuntime tfLiteRuntime;
  17935. -
  17936. - /** Builder of {@link Options}. See its doc for details. */
  17937. - public static class Builder {
  17938. - private Device device = Device.CPU;
  17939. - private int numThreads = 1;
  17940. - private TfLiteRuntime tfLiteRuntime;
  17941. -
  17942. - public Builder setDevice(Device device) {
  17943. - this.device = device;
  17944. - return this;
  17945. - }
  17946. -
  17947. - public Builder setNumThreads(int numThreads) {
  17948. - this.numThreads = numThreads;
  17949. - return this;
  17950. - }
  17951. -
  17952. - public Builder setTfLiteRuntime(TfLiteRuntime tfLiteRuntime) {
  17953. - this.tfLiteRuntime = tfLiteRuntime;
  17954. - return this;
  17955. - }
  17956. -
  17957. - public Options build() {
  17958. - return new Options(this);
  17959. - }
  17960. + /**
  17961. + * Options for running the model. Configurable parameters includes:
  17962. + *
  17963. + * <ul>
  17964. + * <li>{@code device} {@link Builder#setDevice(Device)} specifies the hardware to run the
  17965. + * model. The default value is {@link Device#CPU}. <li>{@code numThreads} {@link
  17966. + * Builder#setNumThreads(int)} specifies the number of threads used by TFLite inference. It's
  17967. + * only effective when device is set to {@link Device#CPU} and default value is 1.
  17968. + * </ul>
  17969. + */
  17970. + public static class Options {
  17971. + private final Device device;
  17972. + private final int numThreads;
  17973. + private final TfLiteRuntime tfLiteRuntime;
  17974. +
  17975. + /** Builder of {@link Options}. See its doc for details. */
  17976. + public static class Builder {
  17977. + private Device device = Device.CPU;
  17978. + private int numThreads = 1;
  17979. + private TfLiteRuntime tfLiteRuntime;
  17980. +
  17981. + public Builder setDevice(Device device) {
  17982. + this.device = device;
  17983. + return this;
  17984. + }
  17985. +
  17986. + public Builder setNumThreads(int numThreads) {
  17987. + this.numThreads = numThreads;
  17988. + return this;
  17989. + }
  17990. +
  17991. + public Builder setTfLiteRuntime(TfLiteRuntime tfLiteRuntime) {
  17992. + this.tfLiteRuntime = tfLiteRuntime;
  17993. + return this;
  17994. + }
  17995. +
  17996. + public Options build() {
  17997. + return new Options(this);
  17998. + }
  17999. + }
  18000. +
  18001. + private Options(Builder builder) {
  18002. + device = builder.device;
  18003. + numThreads = builder.numThreads;
  18004. + tfLiteRuntime = builder.tfLiteRuntime;
  18005. + }
  18006. }
  18007. - private Options(Builder builder) {
  18008. - device = builder.device;
  18009. - numThreads = builder.numThreads;
  18010. - tfLiteRuntime = builder.tfLiteRuntime;
  18011. - }
  18012. - }
  18013. + /** An instance of the driver class to run model inference with Tensorflow Lite. */
  18014. + private final InterpreterApi interpreter;
  18015. - /** An instance of the driver class to run model inference with Tensorflow Lite. */
  18016. - private final InterpreterApi interpreter;
  18017. + /** Path to tflite model file in asset folder. */
  18018. + private final String modelPath;
  18019. - /** Path to tflite model file in asset folder. */
  18020. - private final String modelPath;
  18021. + /** The memory-mapped model data. */
  18022. + private final MappedByteBuffer byteModel;
  18023. - /** The memory-mapped model data. */
  18024. - private final MappedByteBuffer byteModel;
  18025. + private final GpuDelegateProxy gpuDelegateProxy;
  18026. - private final GpuDelegateProxy gpuDelegateProxy;
  18027. + /**
  18028. + * Builder for {@link Model}.
  18029. + *
  18030. + * @deprecated Please use {@link Model#createModel(Context, String, Options)}.
  18031. + */
  18032. + @Deprecated
  18033. + public static class Builder {
  18034. + private Device device = Device.CPU;
  18035. + private int numThreads = 1;
  18036. + private final String modelPath;
  18037. + private final MappedByteBuffer byteModel;
  18038. +
  18039. + /**
  18040. + * Creates a builder which loads tflite model from asset folder using memory-mapped files.
  18041. + *
  18042. + * @param context Application context to access assets.
  18043. + * @param modelPath Asset path of the model (.tflite file).
  18044. + * @throws IOException if an I/O error occurs when loading the tflite model.
  18045. + */
  18046. + public Builder(@NonNull Context context, @NonNull String modelPath) throws IOException {
  18047. + this.modelPath = modelPath;
  18048. + byteModel = FileUtil.loadMappedFile(context, modelPath);
  18049. + }
  18050. +
  18051. + /** Sets running device. By default, TFLite will run on CPU. */
  18052. + @NonNull
  18053. + public Builder setDevice(Device device) {
  18054. + this.device = device;
  18055. + return this;
  18056. + }
  18057. +
  18058. + /** Sets number of threads. By default it's 1. */
  18059. + @NonNull
  18060. + public Builder setNumThreads(int numThreads) {
  18061. + this.numThreads = numThreads;
  18062. + return this;
  18063. + }
  18064. +
  18065. + // Note: The implementation is copied from `Model#createModel`. As the builder is going to
  18066. + // be deprecated, this function is also to be removed.
  18067. + @NonNull
  18068. + public Model build() {
  18069. + Options options =
  18070. + new Options.Builder().setNumThreads(numThreads).setDevice(device).build();
  18071. + return createModel(byteModel, modelPath, options);
  18072. + }
  18073. + }
  18074. - /**
  18075. - * Builder for {@link Model}.
  18076. - *
  18077. - * @deprecated Please use {@link Model#createModel(Context, String, Options)}.
  18078. - */
  18079. - @Deprecated
  18080. - public static class Builder {
  18081. - private Device device = Device.CPU;
  18082. - private int numThreads = 1;
  18083. - private final String modelPath;
  18084. - private final MappedByteBuffer byteModel;
  18085. + /**
  18086. + * Loads a model from assets and initialize TFLite interpreter.
  18087. + *
  18088. + * <p>The default options are: (1) CPU device; (2) one thread.
  18089. + *
  18090. + * @param context The App Context.
  18091. + * @param modelPath The path of the model file.
  18092. + * @throws IOException if any exception occurs when open the model file.
  18093. + */
  18094. + public static Model createModel(@NonNull Context context, @NonNull String modelPath)
  18095. + throws IOException {
  18096. + return createModel(context, modelPath, new Options.Builder().build());
  18097. + }
  18098. /**
  18099. - * Creates a builder which loads tflite model from asset folder using memory-mapped files.
  18100. + * Loads a model from assets and initialize TFLite interpreter with given options.
  18101. *
  18102. - * @param context Application context to access assets.
  18103. - * @param modelPath Asset path of the model (.tflite file).
  18104. - * @throws IOException if an I/O error occurs when loading the tflite model.
  18105. + * @see Options for details.
  18106. + * @param context The App Context.
  18107. + * @param modelPath The path of the model file.
  18108. + * @param options The options for running the model.
  18109. + * @throws IOException if any exception occurs when open the model file.
  18110. */
  18111. - public Builder(@NonNull Context context, @NonNull String modelPath) throws IOException {
  18112. - this.modelPath = modelPath;
  18113. - byteModel = FileUtil.loadMappedFile(context, modelPath);
  18114. + public static Model createModel(@NonNull Context context, @NonNull String modelPath,
  18115. + @NonNull Options options) throws IOException {
  18116. + SupportPreconditions.checkNotEmpty(
  18117. + modelPath, "Model path in the asset folder cannot be empty.");
  18118. + MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, modelPath);
  18119. + return createModel(byteModel, modelPath, options);
  18120. }
  18121. - /** Sets running device. By default, TFLite will run on CPU. */
  18122. - @NonNull
  18123. - public Builder setDevice(Device device) {
  18124. - this.device = device;
  18125. - return this;
  18126. + /**
  18127. + * Creates a model with loaded {@link MappedByteBuffer}.
  18128. + *
  18129. + * @see Options for details.
  18130. + * @param byteModel The loaded TFLite model.
  18131. + * @param modelPath The original path of the model. It can be fetched later by {@link
  18132. + * Model#getPath()}.
  18133. + * @param options The options for running the model.
  18134. + * @throws IllegalArgumentException if {@code options.device} is {@link Device#GPU} but
  18135. + * "tensorflow-lite-gpu" is not linked to the project.
  18136. + */
  18137. + public static Model createModel(@NonNull MappedByteBuffer byteModel, @NonNull String modelPath,
  18138. + @NonNull Options options) {
  18139. + InterpreterApi.Options interpreterOptions = new InterpreterApi.Options();
  18140. + GpuDelegateProxy gpuDelegateProxy = null;
  18141. + switch (options.device) {
  18142. + case NNAPI:
  18143. + interpreterOptions.setUseNNAPI(true);
  18144. + break;
  18145. + case GPU:
  18146. + gpuDelegateProxy = GpuDelegateProxy.maybeNewInstance();
  18147. + SupportPreconditions.checkArgument(gpuDelegateProxy != null,
  18148. + "Cannot inference with GPU. Did you add \"tensorflow-lite-gpu\" as dependency?");
  18149. + interpreterOptions.addDelegate(gpuDelegateProxy);
  18150. + break;
  18151. + case CPU:
  18152. + break;
  18153. + }
  18154. + interpreterOptions.setNumThreads(options.numThreads);
  18155. + if (options.tfLiteRuntime != null) {
  18156. + interpreterOptions.setRuntime(options.tfLiteRuntime);
  18157. + }
  18158. + InterpreterApi interpreter = InterpreterApi.create(byteModel, interpreterOptions);
  18159. + return new Model(modelPath, byteModel, interpreter, gpuDelegateProxy);
  18160. }
  18161. - /** Sets number of threads. By default it's 1. */
  18162. + /** Returns the memory-mapped model data. */
  18163. @NonNull
  18164. - public Builder setNumThreads(int numThreads) {
  18165. - this.numThreads = numThreads;
  18166. - return this;
  18167. + public MappedByteBuffer getData() {
  18168. + return byteModel;
  18169. }
  18170. - // Note: The implementation is copied from `Model#createModel`. As the builder is going to be
  18171. - // deprecated, this function is also to be removed.
  18172. + /** Returns the path of the model file stored in Assets. */
  18173. @NonNull
  18174. - public Model build() {
  18175. - Options options = new Options.Builder().setNumThreads(numThreads).setDevice(device).build();
  18176. - return createModel(byteModel, modelPath, options);
  18177. + public String getPath() {
  18178. + return modelPath;
  18179. }
  18180. - }
  18181. -
  18182. - /**
  18183. - * Loads a model from assets and initialize TFLite interpreter.
  18184. - *
  18185. - * <p>The default options are: (1) CPU device; (2) one thread.
  18186. - *
  18187. - * @param context The App Context.
  18188. - * @param modelPath The path of the model file.
  18189. - * @throws IOException if any exception occurs when open the model file.
  18190. - */
  18191. - public static Model createModel(@NonNull Context context, @NonNull String modelPath)
  18192. - throws IOException {
  18193. - return createModel(context, modelPath, new Options.Builder().build());
  18194. - }
  18195. -
  18196. - /**
  18197. - * Loads a model from assets and initialize TFLite interpreter with given options.
  18198. - *
  18199. - * @see Options for details.
  18200. - * @param context The App Context.
  18201. - * @param modelPath The path of the model file.
  18202. - * @param options The options for running the model.
  18203. - * @throws IOException if any exception occurs when open the model file.
  18204. - */
  18205. - public static Model createModel(
  18206. - @NonNull Context context, @NonNull String modelPath, @NonNull Options options)
  18207. - throws IOException {
  18208. - SupportPreconditions.checkNotEmpty(
  18209. - modelPath, "Model path in the asset folder cannot be empty.");
  18210. - MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, modelPath);
  18211. - return createModel(byteModel, modelPath, options);
  18212. - }
  18213. -
  18214. - /**
  18215. - * Creates a model with loaded {@link MappedByteBuffer}.
  18216. - *
  18217. - * @see Options for details.
  18218. - * @param byteModel The loaded TFLite model.
  18219. - * @param modelPath The original path of the model. It can be fetched later by {@link
  18220. - * Model#getPath()}.
  18221. - * @param options The options for running the model.
  18222. - * @throws IllegalArgumentException if {@code options.device} is {@link Device#GPU} but
  18223. - * "tensorflow-lite-gpu" is not linked to the project.
  18224. - */
  18225. - public static Model createModel(
  18226. - @NonNull MappedByteBuffer byteModel, @NonNull String modelPath, @NonNull Options options) {
  18227. - InterpreterApi.Options interpreterOptions = new InterpreterApi.Options();
  18228. - GpuDelegateProxy gpuDelegateProxy = null;
  18229. - switch (options.device) {
  18230. - case NNAPI:
  18231. - interpreterOptions.setUseNNAPI(true);
  18232. - break;
  18233. - case GPU:
  18234. - gpuDelegateProxy = GpuDelegateProxy.maybeNewInstance();
  18235. - SupportPreconditions.checkArgument(
  18236. - gpuDelegateProxy != null,
  18237. - "Cannot inference with GPU. Did you add \"tensorflow-lite-gpu\" as dependency?");
  18238. - interpreterOptions.addDelegate(gpuDelegateProxy);
  18239. - break;
  18240. - case CPU:
  18241. - break;
  18242. +
  18243. + /**
  18244. + * Gets the Tensor associated with the provided input index.
  18245. + *
  18246. + * @throws IllegalStateException if the interpreter is closed.
  18247. + */
  18248. + public Tensor getInputTensor(int inputIndex) {
  18249. + return interpreter.getInputTensor(inputIndex);
  18250. }
  18251. - interpreterOptions.setNumThreads(options.numThreads);
  18252. - if (options.tfLiteRuntime != null) {
  18253. - interpreterOptions.setRuntime(options.tfLiteRuntime);
  18254. +
  18255. + /**
  18256. + * Gets the Tensor associated with the provided output index.
  18257. + *
  18258. + * @throws IllegalStateException if the interpreter is closed.
  18259. + */
  18260. + public Tensor getOutputTensor(int outputIndex) {
  18261. + return interpreter.getOutputTensor(outputIndex);
  18262. }
  18263. - InterpreterApi interpreter = InterpreterApi.create(byteModel, interpreterOptions);
  18264. - return new Model(modelPath, byteModel, interpreter, gpuDelegateProxy);
  18265. - }
  18266. -
  18267. - /** Returns the memory-mapped model data. */
  18268. - @NonNull
  18269. - public MappedByteBuffer getData() {
  18270. - return byteModel;
  18271. - }
  18272. -
  18273. - /** Returns the path of the model file stored in Assets. */
  18274. - @NonNull
  18275. - public String getPath() {
  18276. - return modelPath;
  18277. - }
  18278. -
  18279. - /**
  18280. - * Gets the Tensor associated with the provided input index.
  18281. - *
  18282. - * @throws IllegalStateException if the interpreter is closed.
  18283. - */
  18284. - public Tensor getInputTensor(int inputIndex) {
  18285. - return interpreter.getInputTensor(inputIndex);
  18286. - }
  18287. -
  18288. - /**
  18289. - * Gets the Tensor associated with the provided output index.
  18290. - *
  18291. - * @throws IllegalStateException if the interpreter is closed.
  18292. - */
  18293. - public Tensor getOutputTensor(int outputIndex) {
  18294. - return interpreter.getOutputTensor(outputIndex);
  18295. - }
  18296. -
  18297. - /**
  18298. - * Returns the output shape. Useful if output shape is only determined when graph is created.
  18299. - *
  18300. - * @throws IllegalStateException if the interpreter is closed.
  18301. - */
  18302. - public int[] getOutputTensorShape(int outputIndex) {
  18303. - return interpreter.getOutputTensor(outputIndex).shape();
  18304. - }
  18305. -
  18306. - /**
  18307. - * Runs model inference on multiple inputs, and returns multiple outputs.
  18308. - *
  18309. - * @param inputs an array of input data. The inputs should be in the same order as inputs of the
  18310. - * model. Each input can be an array or multidimensional array, or a {@link
  18311. - * java.nio.ByteBuffer} of primitive types including int, float, long, and byte. {@link
  18312. - * java.nio.ByteBuffer} is the preferred way to pass large input data, whereas string types
  18313. - * require using the (multi-dimensional) array input path. When {@link java.nio.ByteBuffer} is
  18314. - * used, its content should remain unchanged until model inference is done.
  18315. - * @param outputs a map mapping output indices to multidimensional arrays of output data or {@link
  18316. - * java.nio.ByteBuffer}s of primitive types including int, float, long, and byte. It only
  18317. - * needs to keep entries for the outputs to be used.
  18318. - */
  18319. - public void run(@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
  18320. - interpreter.runForMultipleInputsOutputs(inputs, outputs);
  18321. - }
  18322. -
  18323. - public void close() {
  18324. - if (interpreter != null) {
  18325. - interpreter.close();
  18326. +
  18327. + /**
  18328. + * Returns the output shape. Useful if output shape is only determined when graph is created.
  18329. + *
  18330. + * @throws IllegalStateException if the interpreter is closed.
  18331. + */
  18332. + public int[] getOutputTensorShape(int outputIndex) {
  18333. + return interpreter.getOutputTensor(outputIndex).shape();
  18334. + }
  18335. +
  18336. + /**
  18337. + * Runs model inference on multiple inputs, and returns multiple outputs.
  18338. + *
  18339. + * @param inputs an array of input data. The inputs should be in the same order as inputs of the
  18340. + * model. Each input can be an array or multidimensional array, or a {@link
  18341. + * java.nio.ByteBuffer} of primitive types including int, float, long, and byte. {@link
  18342. + * java.nio.ByteBuffer} is the preferred way to pass large input data, whereas string types
  18343. + * require using the (multi-dimensional) array input path. When {@link java.nio.ByteBuffer}
  18344. + * is used, its content should remain unchanged until model inference is done.
  18345. + * @param outputs a map mapping output indices to multidimensional arrays of output data or
  18346. + * {@link
  18347. + * java.nio.ByteBuffer}s of primitive types including int, float, long, and byte. It only
  18348. + * needs to keep entries for the outputs to be used.
  18349. + */
  18350. + public void run(@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
  18351. + interpreter.runForMultipleInputsOutputs(inputs, outputs);
  18352. }
  18353. - if (gpuDelegateProxy != null) {
  18354. - gpuDelegateProxy.close();
  18355. +
  18356. + public void close() {
  18357. + if (interpreter != null) {
  18358. + interpreter.close();
  18359. + }
  18360. + if (gpuDelegateProxy != null) {
  18361. + gpuDelegateProxy.close();
  18362. + }
  18363. + }
  18364. +
  18365. + private Model(@NonNull String modelPath, @NonNull MappedByteBuffer byteModel,
  18366. + @NonNull InterpreterApi interpreter, @Nullable GpuDelegateProxy gpuDelegateProxy) {
  18367. + this.modelPath = modelPath;
  18368. + this.byteModel = byteModel;
  18369. + this.interpreter = interpreter;
  18370. + this.gpuDelegateProxy = gpuDelegateProxy;
  18371. }
  18372. - }
  18373. -
  18374. - private Model(
  18375. - @NonNull String modelPath,
  18376. - @NonNull MappedByteBuffer byteModel,
  18377. - @NonNull InterpreterApi interpreter,
  18378. - @Nullable GpuDelegateProxy gpuDelegateProxy) {
  18379. - this.modelPath = modelPath;
  18380. - this.byteModel = byteModel;
  18381. - this.interpreter = interpreter;
  18382. - this.gpuDelegateProxy = gpuDelegateProxy;
  18383. - }
  18384. }
  18385. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
  18386. index 9e0204bdc2e71..ec6c800ef557a 100644
  18387. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
  18388. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
  18389. @@ -19,473 +19,476 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
  18390. import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkNotNull;
  18391. import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkState;
  18392. +import org.checkerframework.checker.nullness.qual.NonNull;
  18393. +import org.tensorflow.lite.DataType;
  18394. +
  18395. import java.nio.ByteBuffer;
  18396. import java.nio.ByteOrder;
  18397. import java.util.Arrays;
  18398. -import org.checkerframework.checker.nullness.qual.NonNull;
  18399. -import org.tensorflow.lite.DataType;
  18400. /** Represents the data buffer for either a model's input or its output. */
  18401. public abstract class TensorBuffer {
  18402. - /** Where the data is stored. */
  18403. - protected ByteBuffer buffer;
  18404. -
  18405. - /** Shape of the tensor stored in this buffer. */
  18406. - protected int[] shape;
  18407. -
  18408. - /** Number of elements in the buffer. It will be changed to a proper value in the constructor. */
  18409. - protected int flatSize = -1;
  18410. -
  18411. - /**
  18412. - * Indicator of whether this buffer is dynamic or fixed-size. Fixed-size buffers will have
  18413. - * pre-allocated memory and fixed size. While the size of dynamic buffers can be changed.
  18414. - */
  18415. - protected final boolean isDynamic;
  18416. -
  18417. - /**
  18418. - * Creates a {@link TensorBuffer} with specified {@code shape} and {@link DataType}. Here are some
  18419. - * examples:
  18420. - *
  18421. - * <pre>
  18422. - * // Creating a float TensorBuffer with shape {2, 3}:
  18423. - * int[] shape = new int[] {2, 3};
  18424. - * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  18425. - * </pre>
  18426. - *
  18427. - * <pre>
  18428. - * // Creating an uint8 TensorBuffer of a scalar:
  18429. - * int[] shape = new int[] {};
  18430. - * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  18431. - * </pre>
  18432. - *
  18433. - * <pre>
  18434. - * // Creating an empty uint8 TensorBuffer:
  18435. - * int[] shape = new int[] {0};
  18436. - * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  18437. - * </pre>
  18438. - *
  18439. - * <p>The size of a fixed-size TensorBuffer cannot be changed once it is created.
  18440. - *
  18441. - * @param shape The shape of the {@link TensorBuffer} to be created.
  18442. - * @param dataType The dataType of the {@link TensorBuffer} to be created.
  18443. - * @throws NullPointerException if {@code shape} is null.
  18444. - * @throws IllegalArgumentException if {@code shape} has non-positive elements.
  18445. - */
  18446. - @NonNull
  18447. - public static TensorBuffer createFixedSize(@NonNull int[] shape, DataType dataType) {
  18448. - switch (dataType) {
  18449. - case FLOAT32:
  18450. - return new TensorBufferFloat(shape);
  18451. - case UINT8:
  18452. - return new TensorBufferUint8(shape);
  18453. - default:
  18454. - throw new AssertionError("TensorBuffer does not support data type: " + dataType);
  18455. + /** Where the data is stored. */
  18456. + protected ByteBuffer buffer;
  18457. +
  18458. + /** Shape of the tensor stored in this buffer. */
  18459. + protected int[] shape;
  18460. +
  18461. + /**
  18462. + * Number of elements in the buffer. It will be changed to a proper value in the constructor.
  18463. + */
  18464. + protected int flatSize = -1;
  18465. +
  18466. + /**
  18467. + * Indicator of whether this buffer is dynamic or fixed-size. Fixed-size buffers will have
  18468. + * pre-allocated memory and fixed size. While the size of dynamic buffers can be changed.
  18469. + */
  18470. + protected final boolean isDynamic;
  18471. +
  18472. + /**
  18473. + * Creates a {@link TensorBuffer} with specified {@code shape} and {@link DataType}. Here are
  18474. + * some examples:
  18475. + *
  18476. + * <pre>
  18477. + * // Creating a float TensorBuffer with shape {2, 3}:
  18478. + * int[] shape = new int[] {2, 3};
  18479. + * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  18480. + * </pre>
  18481. + *
  18482. + * <pre>
  18483. + * // Creating an uint8 TensorBuffer of a scalar:
  18484. + * int[] shape = new int[] {};
  18485. + * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  18486. + * </pre>
  18487. + *
  18488. + * <pre>
  18489. + * // Creating an empty uint8 TensorBuffer:
  18490. + * int[] shape = new int[] {0};
  18491. + * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  18492. + * </pre>
  18493. + *
  18494. + * <p>The size of a fixed-size TensorBuffer cannot be changed once it is created.
  18495. + *
  18496. + * @param shape The shape of the {@link TensorBuffer} to be created.
  18497. + * @param dataType The dataType of the {@link TensorBuffer} to be created.
  18498. + * @throws NullPointerException if {@code shape} is null.
  18499. + * @throws IllegalArgumentException if {@code shape} has non-positive elements.
  18500. + */
  18501. + @NonNull
  18502. + public static TensorBuffer createFixedSize(@NonNull int[] shape, DataType dataType) {
  18503. + switch (dataType) {
  18504. + case FLOAT32:
  18505. + return new TensorBufferFloat(shape);
  18506. + case UINT8:
  18507. + return new TensorBufferUint8(shape);
  18508. + default:
  18509. + throw new AssertionError("TensorBuffer does not support data type: " + dataType);
  18510. + }
  18511. + }
  18512. +
  18513. + /**
  18514. + * Creates an empty dynamic {@link TensorBuffer} with specified {@link DataType}. The shape of
  18515. + * the created {@link TensorBuffer} is {0}.
  18516. + *
  18517. + * <p>Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of
  18518. + * different buffer sizes. Here are some examples:
  18519. + *
  18520. + * <pre>
  18521. + * // Creating a float dynamic TensorBuffer:
  18522. + * TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  18523. + * // Loading a float array:
  18524. + * float[] arr1 = new float[] {1, 2, 3};
  18525. + * tensorBuffer.loadArray(arr, new int[] {arr1.length});
  18526. + * // loading another float array:
  18527. + * float[] arr2 = new float[] {1, 2, 3, 4, 5};
  18528. + * tensorBuffer.loadArray(arr, new int[] {arr2.length});
  18529. + * // loading a third float array with the same size as arr2, assuming shape doesn't change:
  18530. + * float[] arr3 = new float[] {5, 4, 3, 2, 1};
  18531. + * tensorBuffer.loadArray(arr);
  18532. + * // loading a forth float array with different size as arr3 and omitting the shape will result
  18533. + * // in error:
  18534. + * float[] arr4 = new float[] {3, 2, 1};
  18535. + * tensorBuffer.loadArray(arr); // Error: The size of byte buffer and the shape do not match.
  18536. + * </pre>
  18537. + *
  18538. + * @param dataType The dataType of the {@link TensorBuffer} to be created.
  18539. + */
  18540. + @NonNull
  18541. + public static TensorBuffer createDynamic(DataType dataType) {
  18542. + switch (dataType) {
  18543. + case FLOAT32:
  18544. + return new TensorBufferFloat();
  18545. + case UINT8:
  18546. + return new TensorBufferUint8();
  18547. + default:
  18548. + throw new AssertionError("TensorBuffer does not support data type: " + dataType);
  18549. + }
  18550. }
  18551. - }
  18552. -
  18553. - /**
  18554. - * Creates an empty dynamic {@link TensorBuffer} with specified {@link DataType}. The shape of the
  18555. - * created {@link TensorBuffer} is {0}.
  18556. - *
  18557. - * <p>Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of
  18558. - * different buffer sizes. Here are some examples:
  18559. - *
  18560. - * <pre>
  18561. - * // Creating a float dynamic TensorBuffer:
  18562. - * TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  18563. - * // Loading a float array:
  18564. - * float[] arr1 = new float[] {1, 2, 3};
  18565. - * tensorBuffer.loadArray(arr, new int[] {arr1.length});
  18566. - * // loading another float array:
  18567. - * float[] arr2 = new float[] {1, 2, 3, 4, 5};
  18568. - * tensorBuffer.loadArray(arr, new int[] {arr2.length});
  18569. - * // loading a third float array with the same size as arr2, assuming shape doesn't change:
  18570. - * float[] arr3 = new float[] {5, 4, 3, 2, 1};
  18571. - * tensorBuffer.loadArray(arr);
  18572. - * // loading a forth float array with different size as arr3 and omitting the shape will result
  18573. - * // in error:
  18574. - * float[] arr4 = new float[] {3, 2, 1};
  18575. - * tensorBuffer.loadArray(arr); // Error: The size of byte buffer and the shape do not match.
  18576. - * </pre>
  18577. - *
  18578. - * @param dataType The dataType of the {@link TensorBuffer} to be created.
  18579. - */
  18580. - @NonNull
  18581. - public static TensorBuffer createDynamic(DataType dataType) {
  18582. - switch (dataType) {
  18583. - case FLOAT32:
  18584. - return new TensorBufferFloat();
  18585. - case UINT8:
  18586. - return new TensorBufferUint8();
  18587. - default:
  18588. - throw new AssertionError("TensorBuffer does not support data type: " + dataType);
  18589. +
  18590. + /**
  18591. + * Creates a {@link TensorBuffer} deep-copying data from another, with specified {@link
  18592. + * DataType}.
  18593. + *
  18594. + * @param buffer the source {@link TensorBuffer} to copy from.
  18595. + * @param dataType the expected {@link DataType} of newly created {@link TensorBuffer}.
  18596. + * @throws NullPointerException if {@code buffer} is null.
  18597. + */
  18598. + @NonNull
  18599. + public static TensorBuffer createFrom(@NonNull TensorBuffer buffer, DataType dataType) {
  18600. + checkNotNull(buffer, "Cannot create a buffer from null");
  18601. + TensorBuffer result;
  18602. + if (buffer.isDynamic()) {
  18603. + result = createDynamic(dataType);
  18604. + } else {
  18605. + result = createFixedSize(buffer.shape, dataType);
  18606. + }
  18607. + // The only scenario we need float array is FLOAT32->FLOAT32, or we can always use INT as
  18608. + // intermediate container.
  18609. + // The assumption is not true when we support other data types.
  18610. + if (buffer.getDataType() == DataType.FLOAT32 && dataType == DataType.FLOAT32) {
  18611. + float[] data = buffer.getFloatArray();
  18612. + result.loadArray(data, buffer.shape);
  18613. + } else {
  18614. + int[] data = buffer.getIntArray();
  18615. + result.loadArray(data, buffer.shape);
  18616. + }
  18617. + return result;
  18618. }
  18619. - }
  18620. -
  18621. - /**
  18622. - * Creates a {@link TensorBuffer} deep-copying data from another, with specified {@link DataType}.
  18623. - *
  18624. - * @param buffer the source {@link TensorBuffer} to copy from.
  18625. - * @param dataType the expected {@link DataType} of newly created {@link TensorBuffer}.
  18626. - * @throws NullPointerException if {@code buffer} is null.
  18627. - */
  18628. - @NonNull
  18629. - public static TensorBuffer createFrom(@NonNull TensorBuffer buffer, DataType dataType) {
  18630. - checkNotNull(buffer, "Cannot create a buffer from null");
  18631. - TensorBuffer result;
  18632. - if (buffer.isDynamic()) {
  18633. - result = createDynamic(dataType);
  18634. - } else {
  18635. - result = createFixedSize(buffer.shape, dataType);
  18636. +
  18637. + /** Returns the data buffer. */
  18638. + @NonNull
  18639. + public ByteBuffer getBuffer() {
  18640. + return buffer;
  18641. }
  18642. - // The only scenario we need float array is FLOAT32->FLOAT32, or we can always use INT as
  18643. - // intermediate container.
  18644. - // The assumption is not true when we support other data types.
  18645. - if (buffer.getDataType() == DataType.FLOAT32 && dataType == DataType.FLOAT32) {
  18646. - float[] data = buffer.getFloatArray();
  18647. - result.loadArray(data, buffer.shape);
  18648. - } else {
  18649. - int[] data = buffer.getIntArray();
  18650. - result.loadArray(data, buffer.shape);
  18651. +
  18652. + /**
  18653. + * Gets the flatSize of the buffer.
  18654. + *
  18655. + * @throws IllegalStateException if the underlying data is corrupted
  18656. + */
  18657. + public int getFlatSize() {
  18658. + assertShapeIsCorrect();
  18659. + return flatSize;
  18660. }
  18661. - return result;
  18662. - }
  18663. -
  18664. - /** Returns the data buffer. */
  18665. - @NonNull
  18666. - public ByteBuffer getBuffer() {
  18667. - return buffer;
  18668. - }
  18669. -
  18670. - /**
  18671. - * Gets the flatSize of the buffer.
  18672. - *
  18673. - * @throws IllegalStateException if the underlying data is corrupted
  18674. - */
  18675. - public int getFlatSize() {
  18676. - assertShapeIsCorrect();
  18677. - return flatSize;
  18678. - }
  18679. -
  18680. - /**
  18681. - * Gets the current shape. (returning a copy here to avoid unexpected modification.)
  18682. - *
  18683. - * @throws IllegalStateException if the underlying data is corrupted
  18684. - */
  18685. - @NonNull
  18686. - public int[] getShape() {
  18687. - assertShapeIsCorrect();
  18688. - return Arrays.copyOf(shape, shape.length);
  18689. - }
  18690. -
  18691. - /** Returns the data type of this buffer. */
  18692. - public abstract DataType getDataType();
  18693. -
  18694. - /**
  18695. - * Returns a float array of the values stored in this buffer. If the buffer is of different types
  18696. - * than float, the values will be converted into float. For example, values in {@link
  18697. - * TensorBufferUint8} will be converted from uint8 to float.
  18698. - */
  18699. - @NonNull
  18700. - public abstract float[] getFloatArray();
  18701. -
  18702. - /**
  18703. - * Returns a float value at a given index. If the buffer is of different types than float, the
  18704. - * value will be converted into float. For example, when reading a value from {@link
  18705. - * TensorBufferUint8}, the value will be first read out as uint8, and then will be converted from
  18706. - * uint8 to float.
  18707. - *
  18708. - * <pre>
  18709. - * For example, a TensorBuffer with shape {2, 3} that represents the following array,
  18710. - * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
  18711. - *
  18712. - * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by:
  18713. - * float v = tensorBuffer.getFloatValue(3);
  18714. - * </pre>
  18715. - *
  18716. - * @param absIndex The absolute index of the value to be read.
  18717. - */
  18718. - public abstract float getFloatValue(int absIndex);
  18719. -
  18720. - /**
  18721. - * Returns an int array of the values stored in this buffer. If the buffer is of different type
  18722. - * than int, the values will be converted into int, and loss of precision may apply. For example,
  18723. - * getting an int array from a {@link TensorBufferFloat} with values {400.32f, 23.04f}, the output
  18724. - * is {400, 23}.
  18725. - */
  18726. - @NonNull
  18727. - public abstract int[] getIntArray();
  18728. -
  18729. - /**
  18730. - * Returns an int value at a given index. If the buffer is of different types than int, the value
  18731. - * will be converted into int. For example, when reading a value from {@link TensorBufferFloat},
  18732. - * the value will be first read out as float, and then will be converted from float to int. Loss
  18733. - * of precision may apply.
  18734. - *
  18735. - * <pre>
  18736. - * For example, a TensorBuffer with shape {2, 3} that represents the following array,
  18737. - * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
  18738. - *
  18739. - * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by:
  18740. - * int v = tensorBuffer.getIntValue(3);
  18741. - * Note that v is converted from 3.0f to 3 as a result of type conversion.
  18742. - * </pre>
  18743. - *
  18744. - * @param absIndex The absolute index of the value to be read.
  18745. - */
  18746. - public abstract int getIntValue(int absIndex);
  18747. -
  18748. - /**
  18749. - * Returns the number of bytes of a single element in the array. For example, a float buffer will
  18750. - * return 4, and a byte buffer will return 1.
  18751. - */
  18752. - public abstract int getTypeSize();
  18753. -
  18754. - /** Returns if the {@link TensorBuffer} is dynamic sized (could resize arbitrarily). */
  18755. - public boolean isDynamic() {
  18756. - return isDynamic;
  18757. - }
  18758. -
  18759. - /**
  18760. - * Loads an int array into this buffer with specific shape. If the buffer is of different types
  18761. - * than int, the values will be converted into the buffer's type before being loaded into the
  18762. - * buffer, and loss of precision may apply. For example, loading an int array with values {400,
  18763. - * -23} into a {@link TensorBufferUint8} , the values will be clamped to [0, 255] and then be
  18764. - * casted to uint8 by {255, 0}.
  18765. - *
  18766. - * @param src The source array to be loaded.
  18767. - * @param shape Shape of the tensor that {@code src} represents.
  18768. - * @throws NullPointerException if {@code src} is null.
  18769. - * @throws NullPointerException if {@code shape} is null.
  18770. - * @throws IllegalArgumentException if the size of the array to be loaded does not match the
  18771. - * specified shape.
  18772. - */
  18773. - public abstract void loadArray(@NonNull int[] src, @NonNull int[] shape);
  18774. -
  18775. - /**
  18776. - * Loads an int array into this buffer. If the buffer is of different types than int, the values
  18777. - * will be converted into the buffer's type before being loaded into the buffer, and loss of
  18778. - * precision may apply. For example, loading an int array with values {400, -23} into a {@link
  18779. - * TensorBufferUint8} , the values will be clamped to [0, 255] and then be casted to uint8 by
  18780. - * {255, 0}.
  18781. - *
  18782. - * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this
  18783. - * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always match
  18784. - * the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
  18785. - * TensorBuffer}. Use {@link #loadArray(int[], int[])} if {@code src} has a different shape.
  18786. - *
  18787. - * @param src The source array to be loaded.
  18788. - */
  18789. - public void loadArray(@NonNull int[] src) {
  18790. - loadArray(src, shape);
  18791. - }
  18792. -
  18793. - /**
  18794. - * Loads a float array into this buffer with specific shape. If the buffer is of different types
  18795. - * than float, the values will be converted into the buffer's type before being loaded into the
  18796. - * buffer, and loss of precision may apply. For example, loading a float array into a {@link
  18797. - * TensorBufferUint8} with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and
  18798. - * then be casted to uint8 by {255, 0}.
  18799. - *
  18800. - * @param src The source array to be loaded.
  18801. - * @param shape Shape of the tensor that {@code src} represents.
  18802. - * @throws NullPointerException if {@code src} is null.
  18803. - * @throws NullPointerException if {@code shape} is null.
  18804. - * @throws IllegalArgumentException if the size of the array to be loaded does not match the
  18805. - * specified shape.
  18806. - */
  18807. - public abstract void loadArray(@NonNull float[] src, @NonNull int[] shape);
  18808. -
  18809. - /**
  18810. - * Loads a float array into this buffer. If the buffer is of different types than float, the
  18811. - * values will be converted into the buffer's type before being loaded into the buffer, and loss
  18812. - * of precision may apply. For example, loading a float array into a {@link TensorBufferUint8}
  18813. - * with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and then be casted to
  18814. - * uint8 by {255, 0}.
  18815. - *
  18816. - * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this
  18817. - * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always match
  18818. - * the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
  18819. - * TensorBuffer}. Use {@link #loadArray(float[], int[])} if {@code src} has a different shape.
  18820. - *
  18821. - * @param src The source array to be loaded.
  18822. - */
  18823. - public void loadArray(@NonNull float[] src) {
  18824. - loadArray(src, shape);
  18825. - }
  18826. -
  18827. - /**
  18828. - * Loads a byte buffer into this {@link TensorBuffer} with specific shape.
  18829. - *
  18830. - * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for
  18831. - * performance concern, but if modification is necessary, please make a copy.
  18832. - *
  18833. - * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer}
  18834. - * backed by an array.
  18835. - *
  18836. - * @param buffer The byte buffer to load.
  18837. - * @throws NullPointerException if {@code buffer} is null.
  18838. - * @throws IllegalArgumentException if the size of {@code buffer} and {@code typeSize} do not
  18839. - * match or the size of {@code buffer} and {@code flatSize} do not match.
  18840. - */
  18841. - public void loadBuffer(@NonNull ByteBuffer buffer, @NonNull int[] shape) {
  18842. - checkNotNull(buffer, "Byte buffer cannot be null.");
  18843. - checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative.");
  18844. -
  18845. - int flatSize = computeFlatSize(shape);
  18846. - checkArgument(
  18847. - (buffer.limit() == getTypeSize() * flatSize),
  18848. - "The size of byte buffer and the shape do not match. Expected: "
  18849. - + getTypeSize() * flatSize
  18850. - + " Actual: "
  18851. - + buffer.limit());
  18852. -
  18853. - if (!isDynamic) {
  18854. - // Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
  18855. - checkArgument(Arrays.equals(shape, this.shape));
  18856. +
  18857. + /**
  18858. + * Gets the current shape. (returning a copy here to avoid unexpected modification.)
  18859. + *
  18860. + * @throws IllegalStateException if the underlying data is corrupted
  18861. + */
  18862. + @NonNull
  18863. + public int[] getShape() {
  18864. + assertShapeIsCorrect();
  18865. + return Arrays.copyOf(shape, shape.length);
  18866. }
  18867. - // Update to the new shape, since shape dim values might change.
  18868. - this.shape = shape.clone();
  18869. - this.flatSize = flatSize;
  18870. -
  18871. - buffer.rewind();
  18872. - this.buffer = buffer;
  18873. - }
  18874. -
  18875. - /**
  18876. - * Loads a byte buffer into this {@link TensorBuffer}. Buffer size must match the flat size of
  18877. - * this {@link TensorBuffer}.
  18878. - *
  18879. - * <p>Using this method assumes that the shape of {@code buffer} is the same as the shape of this
  18880. - * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code buffer.limit()}) should always
  18881. - * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
  18882. - * TensorBuffer}. Use {@link #loadBuffer(ByteBuffer, int[])} if {@code buffer} has a different
  18883. - * shape.
  18884. - *
  18885. - * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for
  18886. - * performance concern, but if modification is necessary, please make a copy.
  18887. - *
  18888. - * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer}
  18889. - * backed by an array.
  18890. - *
  18891. - * <p>If the {@code buffer} is read-only, we adopt a copy-on-write strategy for performance.
  18892. - *
  18893. - * @param buffer The byte buffer to load.
  18894. - */
  18895. - public void loadBuffer(@NonNull ByteBuffer buffer) {
  18896. - loadBuffer(buffer, shape);
  18897. - }
  18898. -
  18899. - /**
  18900. - * Constructs a fixed size {@link TensorBuffer} with specified {@code shape}.
  18901. - *
  18902. - * @throws NullPointerException if {@code shape} is null.
  18903. - * @throws IllegalArgumentException if {@code shape} has non-positive elements.
  18904. - */
  18905. - protected TensorBuffer(@NonNull int[] shape) {
  18906. - isDynamic = false;
  18907. - allocateMemory(shape);
  18908. - }
  18909. -
  18910. - /** Constructs a dynamic {@link TensorBuffer} which can be resized. */
  18911. - protected TensorBuffer() {
  18912. - isDynamic = true;
  18913. - // Initialize the dynamic TensorBuffer with an empty ByteBuffer.
  18914. - allocateMemory(new int[] {0});
  18915. - }
  18916. -
  18917. - /** Calculates number of elements in the buffer. */
  18918. - protected static int computeFlatSize(@NonNull int[] shape) {
  18919. - checkNotNull(shape, "Shape cannot be null.");
  18920. - int prod = 1;
  18921. - for (int s : shape) {
  18922. - prod = prod * s;
  18923. + /** Returns the data type of this buffer. */
  18924. + public abstract DataType getDataType();
  18925. +
  18926. + /**
  18927. + * Returns a float array of the values stored in this buffer. If the buffer is of different
  18928. + * types than float, the values will be converted into float. For example, values in {@link
  18929. + * TensorBufferUint8} will be converted from uint8 to float.
  18930. + */
  18931. + @NonNull
  18932. + public abstract float[] getFloatArray();
  18933. +
  18934. + /**
  18935. + * Returns a float value at a given index. If the buffer is of different types than float, the
  18936. + * value will be converted into float. For example, when reading a value from {@link
  18937. + * TensorBufferUint8}, the value will be first read out as uint8, and then will be converted
  18938. + * from uint8 to float.
  18939. + *
  18940. + * <pre>
  18941. + * For example, a TensorBuffer with shape {2, 3} that represents the following array,
  18942. + * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
  18943. + *
  18944. + * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by:
  18945. + * float v = tensorBuffer.getFloatValue(3);
  18946. + * </pre>
  18947. + *
  18948. + * @param absIndex The absolute index of the value to be read.
  18949. + */
  18950. + public abstract float getFloatValue(int absIndex);
  18951. +
  18952. + /**
  18953. + * Returns an int array of the values stored in this buffer. If the buffer is of different type
  18954. + * than int, the values will be converted into int, and loss of precision may apply. For
  18955. + * example, getting an int array from a {@link TensorBufferFloat} with values {400.32f, 23.04f},
  18956. + * the output is {400, 23}.
  18957. + */
  18958. + @NonNull
  18959. + public abstract int[] getIntArray();
  18960. +
  18961. + /**
  18962. + * Returns an int value at a given index. If the buffer is of different types than int, the
  18963. + * value will be converted into int. For example, when reading a value from {@link
  18964. + * TensorBufferFloat}, the value will be first read out as float, and then will be converted
  18965. + * from float to int. Loss of precision may apply.
  18966. + *
  18967. + * <pre>
  18968. + * For example, a TensorBuffer with shape {2, 3} that represents the following array,
  18969. + * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
  18970. + *
  18971. + * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by:
  18972. + * int v = tensorBuffer.getIntValue(3);
  18973. + * Note that v is converted from 3.0f to 3 as a result of type conversion.
  18974. + * </pre>
  18975. + *
  18976. + * @param absIndex The absolute index of the value to be read.
  18977. + */
  18978. + public abstract int getIntValue(int absIndex);
  18979. +
  18980. + /**
  18981. + * Returns the number of bytes of a single element in the array. For example, a float buffer
  18982. + * will return 4, and a byte buffer will return 1.
  18983. + */
  18984. + public abstract int getTypeSize();
  18985. +
  18986. + /** Returns if the {@link TensorBuffer} is dynamic sized (could resize arbitrarily). */
  18987. + public boolean isDynamic() {
  18988. + return isDynamic;
  18989. }
  18990. - return prod;
  18991. - }
  18992. -
  18993. - /**
  18994. - * For dynamic buffer, resize the memory if needed. For fixed-size buffer, check if the {@code
  18995. - * shape} of src fits the buffer size.
  18996. - */
  18997. - protected void resize(@NonNull int[] shape) {
  18998. - if (isDynamic) {
  18999. - allocateMemory(shape);
  19000. - } else {
  19001. - // Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
  19002. - checkArgument(Arrays.equals(shape, this.shape));
  19003. - this.shape = shape.clone();
  19004. +
  19005. + /**
  19006. + * Loads an int array into this buffer with specific shape. If the buffer is of different types
  19007. + * than int, the values will be converted into the buffer's type before being loaded into the
  19008. + * buffer, and loss of precision may apply. For example, loading an int array with values {400,
  19009. + * -23} into a {@link TensorBufferUint8} , the values will be clamped to [0, 255] and then be
  19010. + * casted to uint8 by {255, 0}.
  19011. + *
  19012. + * @param src The source array to be loaded.
  19013. + * @param shape Shape of the tensor that {@code src} represents.
  19014. + * @throws NullPointerException if {@code src} is null.
  19015. + * @throws NullPointerException if {@code shape} is null.
  19016. + * @throws IllegalArgumentException if the size of the array to be loaded does not match the
  19017. + * specified shape.
  19018. + */
  19019. + public abstract void loadArray(@NonNull int[] src, @NonNull int[] shape);
  19020. +
  19021. + /**
  19022. + * Loads an int array into this buffer. If the buffer is of different types than int, the values
  19023. + * will be converted into the buffer's type before being loaded into the buffer, and loss of
  19024. + * precision may apply. For example, loading an int array with values {400, -23} into a {@link
  19025. + * TensorBufferUint8} , the values will be clamped to [0, 255] and then be casted to uint8 by
  19026. + * {255, 0}.
  19027. + *
  19028. + * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this
  19029. + * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always
  19030. + * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
  19031. + * TensorBuffer}. Use {@link #loadArray(int[], int[])} if {@code src} has a different shape.
  19032. + *
  19033. + * @param src The source array to be loaded.
  19034. + */
  19035. + public void loadArray(@NonNull int[] src) {
  19036. + loadArray(src, shape);
  19037. + }
  19038. +
  19039. + /**
  19040. + * Loads a float array into this buffer with specific shape. If the buffer is of different types
  19041. + * than float, the values will be converted into the buffer's type before being loaded into the
  19042. + * buffer, and loss of precision may apply. For example, loading a float array into a {@link
  19043. + * TensorBufferUint8} with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and
  19044. + * then be casted to uint8 by {255, 0}.
  19045. + *
  19046. + * @param src The source array to be loaded.
  19047. + * @param shape Shape of the tensor that {@code src} represents.
  19048. + * @throws NullPointerException if {@code src} is null.
  19049. + * @throws NullPointerException if {@code shape} is null.
  19050. + * @throws IllegalArgumentException if the size of the array to be loaded does not match the
  19051. + * specified shape.
  19052. + */
  19053. + public abstract void loadArray(@NonNull float[] src, @NonNull int[] shape);
  19054. +
  19055. + /**
  19056. + * Loads a float array into this buffer. If the buffer is of different types than float, the
  19057. + * values will be converted into the buffer's type before being loaded into the buffer, and loss
  19058. + * of precision may apply. For example, loading a float array into a {@link TensorBufferUint8}
  19059. + * with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and then be casted to
  19060. + * uint8 by {255, 0}.
  19061. + *
  19062. + * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this
  19063. + * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always
  19064. + * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
  19065. + * TensorBuffer}. Use {@link #loadArray(float[], int[])} if {@code src} has a different shape.
  19066. + *
  19067. + * @param src The source array to be loaded.
  19068. + */
  19069. + public void loadArray(@NonNull float[] src) {
  19070. + loadArray(src, shape);
  19071. }
  19072. - }
  19073. - /** Copies the underlying {@link ByteBuffer} if it's readonly. */
  19074. - protected synchronized void copyByteBufferIfReadOnly() {
  19075. - if (!buffer.isReadOnly()) {
  19076. - return;
  19077. + /**
  19078. + * Loads a byte buffer into this {@link TensorBuffer} with specific shape.
  19079. + *
  19080. + * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here
  19081. + * for performance concern, but if modification is necessary, please make a copy.
  19082. + *
  19083. + * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer}
  19084. + * backed by an array.
  19085. + *
  19086. + * @param buffer The byte buffer to load.
  19087. + * @throws NullPointerException if {@code buffer} is null.
  19088. + * @throws IllegalArgumentException if the size of {@code buffer} and {@code typeSize} do not
  19089. + * match or the size of {@code buffer} and {@code flatSize} do not match.
  19090. + */
  19091. + public void loadBuffer(@NonNull ByteBuffer buffer, @NonNull int[] shape) {
  19092. + checkNotNull(buffer, "Byte buffer cannot be null.");
  19093. + checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative.");
  19094. +
  19095. + int flatSize = computeFlatSize(shape);
  19096. + checkArgument((buffer.limit() == getTypeSize() * flatSize),
  19097. + "The size of byte buffer and the shape do not match. Expected: "
  19098. + + getTypeSize() * flatSize + " Actual: " + buffer.limit());
  19099. +
  19100. + if (!isDynamic) {
  19101. + // Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
  19102. + checkArgument(Arrays.equals(shape, this.shape));
  19103. + }
  19104. +
  19105. + // Update to the new shape, since shape dim values might change.
  19106. + this.shape = shape.clone();
  19107. + this.flatSize = flatSize;
  19108. +
  19109. + buffer.rewind();
  19110. + this.buffer = buffer;
  19111. }
  19112. - ByteBuffer newByteBuffer = ByteBuffer.allocateDirect(buffer.capacity());
  19113. - newByteBuffer.order(buffer.order());
  19114. - newByteBuffer.put(buffer);
  19115. - newByteBuffer.rewind();
  19116. - buffer = newByteBuffer;
  19117. - }
  19118. -
  19119. - /**
  19120. - * Allocates buffer with corresponding size of the {@code shape}. If shape is an empty array, this
  19121. - * {@link TensorBuffer} will be created as a scalar and its flatSize will be 1.
  19122. - *
  19123. - * @throws NullPointerException if {@code shape} is null.
  19124. - * @throws IllegalArgumentException if {@code shape} has negative elements.
  19125. - */
  19126. - private void allocateMemory(@NonNull int[] shape) {
  19127. - checkNotNull(shape, "TensorBuffer shape cannot be null.");
  19128. - checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative.");
  19129. -
  19130. - // Check if the new shape is the same as current shape.
  19131. - int newFlatSize = computeFlatSize(shape);
  19132. - this.shape = shape.clone();
  19133. - if (flatSize == newFlatSize) {
  19134. - return;
  19135. +
  19136. + /**
  19137. + * Loads a byte buffer into this {@link TensorBuffer}. Buffer size must match the flat size of
  19138. + * this {@link TensorBuffer}.
  19139. + *
  19140. + * <p>Using this method assumes that the shape of {@code buffer} is the same as the shape of
  19141. + * this
  19142. + * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code buffer.limit()}) should always
  19143. + * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
  19144. + * TensorBuffer}. Use {@link #loadBuffer(ByteBuffer, int[])} if {@code buffer} has a different
  19145. + * shape.
  19146. + *
  19147. + * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here
  19148. + * for performance concern, but if modification is necessary, please make a copy.
  19149. + *
  19150. + * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer}
  19151. + * backed by an array.
  19152. + *
  19153. + * <p>If the {@code buffer} is read-only, we adopt a copy-on-write strategy for performance.
  19154. + *
  19155. + * @param buffer The byte buffer to load.
  19156. + */
  19157. + public void loadBuffer(@NonNull ByteBuffer buffer) {
  19158. + loadBuffer(buffer, shape);
  19159. + }
  19160. +
  19161. + /**
  19162. + * Constructs a fixed size {@link TensorBuffer} with specified {@code shape}.
  19163. + *
  19164. + * @throws NullPointerException if {@code shape} is null.
  19165. + * @throws IllegalArgumentException if {@code shape} has non-positive elements.
  19166. + */
  19167. + protected TensorBuffer(@NonNull int[] shape) {
  19168. + isDynamic = false;
  19169. + allocateMemory(shape);
  19170. + }
  19171. +
  19172. + /** Constructs a dynamic {@link TensorBuffer} which can be resized. */
  19173. + protected TensorBuffer() {
  19174. + isDynamic = true;
  19175. + // Initialize the dynamic TensorBuffer with an empty ByteBuffer.
  19176. + allocateMemory(new int[] {0});
  19177. + }
  19178. +
  19179. + /** Calculates number of elements in the buffer. */
  19180. + protected static int computeFlatSize(@NonNull int[] shape) {
  19181. + checkNotNull(shape, "Shape cannot be null.");
  19182. + int prod = 1;
  19183. + for (int s : shape) {
  19184. + prod = prod * s;
  19185. + }
  19186. + return prod;
  19187. + }
  19188. +
  19189. + /**
  19190. + * For dynamic buffer, resize the memory if needed. For fixed-size buffer, check if the {@code
  19191. + * shape} of src fits the buffer size.
  19192. + */
  19193. + protected void resize(@NonNull int[] shape) {
  19194. + if (isDynamic) {
  19195. + allocateMemory(shape);
  19196. + } else {
  19197. + // Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
  19198. + checkArgument(Arrays.equals(shape, this.shape));
  19199. + this.shape = shape.clone();
  19200. + }
  19201. + }
  19202. +
  19203. + /** Copies the underlying {@link ByteBuffer} if it's readonly. */
  19204. + protected synchronized void copyByteBufferIfReadOnly() {
  19205. + if (!buffer.isReadOnly()) {
  19206. + return;
  19207. + }
  19208. + ByteBuffer newByteBuffer = ByteBuffer.allocateDirect(buffer.capacity());
  19209. + newByteBuffer.order(buffer.order());
  19210. + newByteBuffer.put(buffer);
  19211. + newByteBuffer.rewind();
  19212. + buffer = newByteBuffer;
  19213. + }
  19214. +
  19215. + /**
  19216. + * Allocates buffer with corresponding size of the {@code shape}. If shape is an empty array,
  19217. + * this
  19218. + * {@link TensorBuffer} will be created as a scalar and its flatSize will be 1.
  19219. + *
  19220. + * @throws NullPointerException if {@code shape} is null.
  19221. + * @throws IllegalArgumentException if {@code shape} has negative elements.
  19222. + */
  19223. + private void allocateMemory(@NonNull int[] shape) {
  19224. + checkNotNull(shape, "TensorBuffer shape cannot be null.");
  19225. + checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative.");
  19226. +
  19227. + // Check if the new shape is the same as current shape.
  19228. + int newFlatSize = computeFlatSize(shape);
  19229. + this.shape = shape.clone();
  19230. + if (flatSize == newFlatSize) {
  19231. + return;
  19232. + }
  19233. +
  19234. + // Update to the new shape.
  19235. + flatSize = newFlatSize;
  19236. + buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize());
  19237. + buffer.order(ByteOrder.nativeOrder());
  19238. }
  19239. - // Update to the new shape.
  19240. - flatSize = newFlatSize;
  19241. - buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize());
  19242. - buffer.order(ByteOrder.nativeOrder());
  19243. - }
  19244. -
  19245. - /**
  19246. - * Verifies if the shape of the {@link TensorBuffer} matched the size of the underlying {@link
  19247. - * ByteBuffer}.
  19248. - */
  19249. - private void assertShapeIsCorrect() {
  19250. - int flatSize = computeFlatSize(shape);
  19251. - checkState(
  19252. - (buffer.limit() == getTypeSize() * flatSize),
  19253. - String.format(
  19254. - "The size of underlying ByteBuffer (%d) and the shape (%s) do not match. The"
  19255. - + " ByteBuffer may have been changed.",
  19256. - buffer.limit(), Arrays.toString(shape)));
  19257. - }
  19258. -
  19259. - /**
  19260. - * Checks if {@code shape} meets one of following two requirements: 1. Elements in {@code shape}
  19261. - * are all non-negative numbers. 2. {@code shape} is an empty array, which corresponds to scalar.
  19262. - */
  19263. - private static boolean isShapeValid(@NonNull int[] shape) {
  19264. - if (shape.length == 0) {
  19265. - // This shape refers to a scalar.
  19266. - return true;
  19267. + /**
  19268. + * Verifies if the shape of the {@link TensorBuffer} matched the size of the underlying {@link
  19269. + * ByteBuffer}.
  19270. + */
  19271. + private void assertShapeIsCorrect() {
  19272. + int flatSize = computeFlatSize(shape);
  19273. + checkState((buffer.limit() == getTypeSize() * flatSize),
  19274. + String.format(
  19275. + "The size of underlying ByteBuffer (%d) and the shape (%s) do not match. The"
  19276. + + " ByteBuffer may have been changed.",
  19277. + buffer.limit(), Arrays.toString(shape)));
  19278. }
  19279. - // This shape refers to a multidimensional array.
  19280. - for (int s : shape) {
  19281. - // All elements in shape should be non-negative.
  19282. - if (s < 0) {
  19283. - return false;
  19284. - }
  19285. + /**
  19286. + * Checks if {@code shape} meets one of following two requirements: 1. Elements in {@code shape}
  19287. + * are all non-negative numbers. 2. {@code shape} is an empty array, which corresponds to
  19288. + * scalar.
  19289. + */
  19290. + private static boolean isShapeValid(@NonNull int[] shape) {
  19291. + if (shape.length == 0) {
  19292. + // This shape refers to a scalar.
  19293. + return true;
  19294. + }
  19295. +
  19296. + // This shape refers to a multidimensional array.
  19297. + for (int s : shape) {
  19298. + // All elements in shape should be non-negative.
  19299. + if (s < 0) {
  19300. + return false;
  19301. + }
  19302. + }
  19303. + return true;
  19304. }
  19305. - return true;
  19306. - }
  19307. }
  19308. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java
  19309. index 8d2bc5ad0c84d..632db6c886b17 100644
  19310. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java
  19311. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java
  19312. @@ -15,103 +15,102 @@ limitations under the License.
  19313. package org.tensorflow.lite.support.tensorbuffer;
  19314. -import java.nio.FloatBuffer;
  19315. import org.checkerframework.checker.nullness.qual.NonNull;
  19316. import org.tensorflow.lite.DataType;
  19317. import org.tensorflow.lite.support.common.internal.SupportPreconditions;
  19318. +import java.nio.FloatBuffer;
  19319. +
  19320. /** Represents data buffer with float values. */
  19321. public final class TensorBufferFloat extends TensorBuffer {
  19322. - private static final DataType DATA_TYPE = DataType.FLOAT32;
  19323. -
  19324. - /**
  19325. - * Creates a {@link TensorBufferFloat} with specified {@code shape}.
  19326. - *
  19327. - * @throws NullPointerException if {@code shape} is null.
  19328. - * @throws IllegalArgumentException if {@code shape} has non-positive elements.
  19329. - */
  19330. - TensorBufferFloat(@NonNull int[] shape) {
  19331. - super(shape);
  19332. - }
  19333. -
  19334. - TensorBufferFloat() {
  19335. - super();
  19336. - }
  19337. -
  19338. - @Override
  19339. - public DataType getDataType() {
  19340. - return DATA_TYPE;
  19341. - }
  19342. -
  19343. - @Override
  19344. - @NonNull
  19345. - public float[] getFloatArray() {
  19346. - buffer.rewind();
  19347. - float[] arr = new float[flatSize];
  19348. -
  19349. - FloatBuffer floatBuffer = buffer.asFloatBuffer();
  19350. - floatBuffer.get(arr);
  19351. - return arr;
  19352. - }
  19353. -
  19354. - @Override
  19355. - public float getFloatValue(int absIndex) {
  19356. - return buffer.getFloat(absIndex << 2);
  19357. - }
  19358. -
  19359. - @Override
  19360. - @NonNull
  19361. - public int[] getIntArray() {
  19362. - buffer.rewind();
  19363. - float[] floatArr = new float[flatSize];
  19364. - buffer.asFloatBuffer().get(floatArr);
  19365. -
  19366. - int[] intArr = new int[flatSize];
  19367. - for (int i = 0; i < flatSize; i++) {
  19368. - intArr[i] = (int) floatArr[i];
  19369. + private static final DataType DATA_TYPE = DataType.FLOAT32;
  19370. +
  19371. + /**
  19372. + * Creates a {@link TensorBufferFloat} with specified {@code shape}.
  19373. + *
  19374. + * @throws NullPointerException if {@code shape} is null.
  19375. + * @throws IllegalArgumentException if {@code shape} has non-positive elements.
  19376. + */
  19377. + TensorBufferFloat(@NonNull int[] shape) {
  19378. + super(shape);
  19379. }
  19380. - return intArr;
  19381. - }
  19382. -
  19383. - @Override
  19384. - public int getIntValue(int absIndex) {
  19385. - return (int) buffer.getFloat(absIndex << 2);
  19386. - }
  19387. -
  19388. - @Override
  19389. - public int getTypeSize() {
  19390. - return DATA_TYPE.byteSize();
  19391. - }
  19392. -
  19393. - @Override
  19394. - public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
  19395. - SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
  19396. - SupportPreconditions.checkArgument(
  19397. - src.length == computeFlatSize(shape),
  19398. - "The size of the array to be loaded does not match the specified shape.");
  19399. - copyByteBufferIfReadOnly();
  19400. - resize(shape);
  19401. - buffer.rewind();
  19402. -
  19403. - FloatBuffer floatBuffer = buffer.asFloatBuffer();
  19404. - floatBuffer.put(src);
  19405. - }
  19406. -
  19407. - @Override
  19408. - public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
  19409. - SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
  19410. - SupportPreconditions.checkArgument(
  19411. - src.length == computeFlatSize(shape),
  19412. - "The size of the array to be loaded does not match the specified shape.");
  19413. - copyByteBufferIfReadOnly();
  19414. - resize(shape);
  19415. - buffer.rewind();
  19416. -
  19417. - float[] floatArray = new float[src.length];
  19418. - int cnt = 0;
  19419. - for (int a : src) {
  19420. - floatArray[cnt++] = (float) a;
  19421. +
  19422. + TensorBufferFloat() {
  19423. + super();
  19424. + }
  19425. +
  19426. + @Override
  19427. + public DataType getDataType() {
  19428. + return DATA_TYPE;
  19429. + }
  19430. +
  19431. + @Override
  19432. + @NonNull
  19433. + public float[] getFloatArray() {
  19434. + buffer.rewind();
  19435. + float[] arr = new float[flatSize];
  19436. +
  19437. + FloatBuffer floatBuffer = buffer.asFloatBuffer();
  19438. + floatBuffer.get(arr);
  19439. + return arr;
  19440. + }
  19441. +
  19442. + @Override
  19443. + public float getFloatValue(int absIndex) {
  19444. + return buffer.getFloat(absIndex << 2);
  19445. + }
  19446. +
  19447. + @Override
  19448. + @NonNull
  19449. + public int[] getIntArray() {
  19450. + buffer.rewind();
  19451. + float[] floatArr = new float[flatSize];
  19452. + buffer.asFloatBuffer().get(floatArr);
  19453. +
  19454. + int[] intArr = new int[flatSize];
  19455. + for (int i = 0; i < flatSize; i++) {
  19456. + intArr[i] = (int) floatArr[i];
  19457. + }
  19458. + return intArr;
  19459. + }
  19460. +
  19461. + @Override
  19462. + public int getIntValue(int absIndex) {
  19463. + return (int) buffer.getFloat(absIndex << 2);
  19464. + }
  19465. +
  19466. + @Override
  19467. + public int getTypeSize() {
  19468. + return DATA_TYPE.byteSize();
  19469. + }
  19470. +
  19471. + @Override
  19472. + public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
  19473. + SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
  19474. + SupportPreconditions.checkArgument(src.length == computeFlatSize(shape),
  19475. + "The size of the array to be loaded does not match the specified shape.");
  19476. + copyByteBufferIfReadOnly();
  19477. + resize(shape);
  19478. + buffer.rewind();
  19479. +
  19480. + FloatBuffer floatBuffer = buffer.asFloatBuffer();
  19481. + floatBuffer.put(src);
  19482. + }
  19483. +
  19484. + @Override
  19485. + public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
  19486. + SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
  19487. + SupportPreconditions.checkArgument(src.length == computeFlatSize(shape),
  19488. + "The size of the array to be loaded does not match the specified shape.");
  19489. + copyByteBufferIfReadOnly();
  19490. + resize(shape);
  19491. + buffer.rewind();
  19492. +
  19493. + float[] floatArray = new float[src.length];
  19494. + int cnt = 0;
  19495. + for (int a : src) {
  19496. + floatArray[cnt++] = (float) a;
  19497. + }
  19498. + buffer.asFloatBuffer().put(floatArray);
  19499. }
  19500. - buffer.asFloatBuffer().put(floatArray);
  19501. - }
  19502. }
  19503. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java
  19504. index b2fa466e5be92..2924ef0af6c11 100644
  19505. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java
  19506. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java
  19507. @@ -21,103 +21,101 @@ import org.tensorflow.lite.support.common.internal.SupportPreconditions;
  19508. /** Represents data buffer with 8-bit unsigned integer values. */
  19509. public final class TensorBufferUint8 extends TensorBuffer {
  19510. - private static final DataType DATA_TYPE = DataType.UINT8;
  19511. -
  19512. - /**
  19513. - * Creates a {@link TensorBufferUint8} with specified {@code shape}.
  19514. - *
  19515. - * @throws NullPointerException if {@code shape} is null.
  19516. - * @throws IllegalArgumentException if {@code shape} has non-positive elements.
  19517. - */
  19518. - TensorBufferUint8(@NonNull int[] shape) {
  19519. - super(shape);
  19520. - }
  19521. -
  19522. - TensorBufferUint8() {
  19523. - super();
  19524. - }
  19525. -
  19526. - @Override
  19527. - public DataType getDataType() {
  19528. - return DATA_TYPE;
  19529. - }
  19530. -
  19531. - @Override
  19532. - @NonNull
  19533. - public float[] getFloatArray() {
  19534. - buffer.rewind();
  19535. - byte[] byteArr = new byte[flatSize];
  19536. - buffer.get(byteArr);
  19537. -
  19538. - float[] floatArr = new float[flatSize];
  19539. - for (int i = 0; i < flatSize; i++) {
  19540. - floatArr[i] = (float) (byteArr[i] & 0xff);
  19541. + private static final DataType DATA_TYPE = DataType.UINT8;
  19542. +
  19543. + /**
  19544. + * Creates a {@link TensorBufferUint8} with specified {@code shape}.
  19545. + *
  19546. + * @throws NullPointerException if {@code shape} is null.
  19547. + * @throws IllegalArgumentException if {@code shape} has non-positive elements.
  19548. + */
  19549. + TensorBufferUint8(@NonNull int[] shape) {
  19550. + super(shape);
  19551. }
  19552. - return floatArr;
  19553. - }
  19554. -
  19555. - @Override
  19556. - public float getFloatValue(int index) {
  19557. - return (float) (buffer.get(index) & 0xff);
  19558. - }
  19559. -
  19560. - @Override
  19561. - @NonNull
  19562. - public int[] getIntArray() {
  19563. - buffer.rewind();
  19564. - byte[] byteArr = new byte[flatSize];
  19565. - buffer.get(byteArr);
  19566. -
  19567. - int[] intArr = new int[flatSize];
  19568. - for (int i = 0; i < flatSize; i++) {
  19569. - intArr[i] = byteArr[i] & 0xff;
  19570. +
  19571. + TensorBufferUint8() {
  19572. + super();
  19573. + }
  19574. +
  19575. + @Override
  19576. + public DataType getDataType() {
  19577. + return DATA_TYPE;
  19578. + }
  19579. +
  19580. + @Override
  19581. + @NonNull
  19582. + public float[] getFloatArray() {
  19583. + buffer.rewind();
  19584. + byte[] byteArr = new byte[flatSize];
  19585. + buffer.get(byteArr);
  19586. +
  19587. + float[] floatArr = new float[flatSize];
  19588. + for (int i = 0; i < flatSize; i++) {
  19589. + floatArr[i] = (float) (byteArr[i] & 0xff);
  19590. + }
  19591. + return floatArr;
  19592. + }
  19593. +
  19594. + @Override
  19595. + public float getFloatValue(int index) {
  19596. + return (float) (buffer.get(index) & 0xff);
  19597. }
  19598. - return intArr;
  19599. - }
  19600. -
  19601. - @Override
  19602. - public int getIntValue(int index) {
  19603. - return buffer.get(index) & 0xff;
  19604. - }
  19605. -
  19606. - @Override
  19607. - public int getTypeSize() {
  19608. - return DATA_TYPE.byteSize();
  19609. - }
  19610. -
  19611. - @Override
  19612. - public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
  19613. - SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
  19614. - SupportPreconditions.checkArgument(
  19615. - src.length == computeFlatSize(shape),
  19616. - "The size of the array to be loaded does not match the specified shape.");
  19617. - copyByteBufferIfReadOnly();
  19618. - resize(shape);
  19619. - buffer.rewind();
  19620. -
  19621. - byte[] byteArr = new byte[src.length];
  19622. - int cnt = 0;
  19623. - for (float a : src) {
  19624. - byteArr[cnt++] = (byte) Math.max(Math.min(a, 255.0), 0.0);
  19625. +
  19626. + @Override
  19627. + @NonNull
  19628. + public int[] getIntArray() {
  19629. + buffer.rewind();
  19630. + byte[] byteArr = new byte[flatSize];
  19631. + buffer.get(byteArr);
  19632. +
  19633. + int[] intArr = new int[flatSize];
  19634. + for (int i = 0; i < flatSize; i++) {
  19635. + intArr[i] = byteArr[i] & 0xff;
  19636. + }
  19637. + return intArr;
  19638. }
  19639. - buffer.put(byteArr);
  19640. - }
  19641. -
  19642. - @Override
  19643. - public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
  19644. - SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
  19645. - SupportPreconditions.checkArgument(
  19646. - src.length == computeFlatSize(shape),
  19647. - "The size of the array to be loaded does not match the specified shape.");
  19648. - copyByteBufferIfReadOnly();
  19649. - resize(shape);
  19650. - buffer.rewind();
  19651. -
  19652. - byte[] byteArr = new byte[src.length];
  19653. - int cnt = 0;
  19654. - for (float a : src) {
  19655. - byteArr[cnt++] = (byte) Math.max(Math.min(a, 255), 0);
  19656. +
  19657. + @Override
  19658. + public int getIntValue(int index) {
  19659. + return buffer.get(index) & 0xff;
  19660. + }
  19661. +
  19662. + @Override
  19663. + public int getTypeSize() {
  19664. + return DATA_TYPE.byteSize();
  19665. + }
  19666. +
  19667. + @Override
  19668. + public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
  19669. + SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
  19670. + SupportPreconditions.checkArgument(src.length == computeFlatSize(shape),
  19671. + "The size of the array to be loaded does not match the specified shape.");
  19672. + copyByteBufferIfReadOnly();
  19673. + resize(shape);
  19674. + buffer.rewind();
  19675. +
  19676. + byte[] byteArr = new byte[src.length];
  19677. + int cnt = 0;
  19678. + for (float a : src) {
  19679. + byteArr[cnt++] = (byte) Math.max(Math.min(a, 255.0), 0.0);
  19680. + }
  19681. + buffer.put(byteArr);
  19682. + }
  19683. +
  19684. + @Override
  19685. + public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
  19686. + SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
  19687. + SupportPreconditions.checkArgument(src.length == computeFlatSize(shape),
  19688. + "The size of the array to be loaded does not match the specified shape.");
  19689. + copyByteBufferIfReadOnly();
  19690. + resize(shape);
  19691. + buffer.rewind();
  19692. +
  19693. + byte[] byteArr = new byte[src.length];
  19694. + int cnt = 0;
  19695. + for (float a : src) {
  19696. + byteArr[cnt++] = (byte) Math.max(Math.min(a, 255), 0);
  19697. + }
  19698. + buffer.put(byteArr);
  19699. }
  19700. - buffer.put(byteArr);
  19701. - }
  19702. }
  19703. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java
  19704. index 043528aa88138..85c5d12e2fc53 100644
  19705. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java
  19706. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java
  19707. @@ -22,13 +22,7 @@ import android.media.AudioFormat;
  19708. import android.media.AudioRecord;
  19709. import android.media.MediaRecorder;
  19710. import android.os.ParcelFileDescriptor;
  19711. -import java.io.File;
  19712. -import java.io.IOException;
  19713. -import java.nio.ByteBuffer;
  19714. -import java.nio.MappedByteBuffer;
  19715. -import java.util.ArrayList;
  19716. -import java.util.Collections;
  19717. -import java.util.List;
  19718. +
  19719. import org.tensorflow.lite.DataType;
  19720. import org.tensorflow.lite.support.audio.TensorAudio;
  19721. import org.tensorflow.lite.support.audio.TensorAudio.TensorAudioFormat;
  19722. @@ -40,6 +34,14 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
  19723. import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider;
  19724. import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  19725. +import java.io.File;
  19726. +import java.io.IOException;
  19727. +import java.nio.ByteBuffer;
  19728. +import java.nio.MappedByteBuffer;
  19729. +import java.util.ArrayList;
  19730. +import java.util.Collections;
  19731. +import java.util.List;
  19732. +
  19733. /**
  19734. * Performs classification on audio waveforms.
  19735. *
  19736. @@ -72,468 +74,437 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  19737. * CLI demo tool</a> for easily trying out this API.
  19738. */
  19739. public final class AudioClassifier extends BaseTaskApi {
  19740. + private static final String AUDIO_CLASSIFIER_NATIVE_LIB = "task_audio_jni";
  19741. + private static final int OPTIONAL_FD_LENGTH = -1;
  19742. + private static final int OPTIONAL_FD_OFFSET = -1;
  19743. +
  19744. + /**
  19745. + * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}.
  19746. + *
  19747. + * @param modelPath path of the classification model with metadata in the assets
  19748. + * @throws IOException if an I/O error occurs when loading the tflite model
  19749. + * @throws IllegalArgumentException if an argument is invalid
  19750. + * @throws IllegalStateException if there is an internal error
  19751. + * @throws RuntimeException if there is an otherwise unspecified error
  19752. + */
  19753. + public static AudioClassifier createFromFile(Context context, String modelPath)
  19754. + throws IOException {
  19755. + return createFromFileAndOptions(
  19756. + context, modelPath, AudioClassifierOptions.builder().build());
  19757. + }
  19758. - private static final String AUDIO_CLASSIFIER_NATIVE_LIB = "task_audio_jni";
  19759. - private static final int OPTIONAL_FD_LENGTH = -1;
  19760. - private static final int OPTIONAL_FD_OFFSET = -1;
  19761. -
  19762. - /**
  19763. - * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}.
  19764. - *
  19765. - * @param modelPath path of the classification model with metadata in the assets
  19766. - * @throws IOException if an I/O error occurs when loading the tflite model
  19767. - * @throws IllegalArgumentException if an argument is invalid
  19768. - * @throws IllegalStateException if there is an internal error
  19769. - * @throws RuntimeException if there is an otherwise unspecified error
  19770. - */
  19771. - public static AudioClassifier createFromFile(Context context, String modelPath)
  19772. - throws IOException {
  19773. - return createFromFileAndOptions(context, modelPath, AudioClassifierOptions.builder().build());
  19774. - }
  19775. -
  19776. - /**
  19777. - * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}.
  19778. - *
  19779. - * @param modelFile the classification model {@link File} instance
  19780. - * @throws IOException if an I/O error occurs when loading the tflite model
  19781. - * @throws IllegalArgumentException if an argument is invalid
  19782. - * @throws IllegalStateException if there is an internal error
  19783. - * @throws RuntimeException if there is an otherwise unspecified error
  19784. - */
  19785. - public static AudioClassifier createFromFile(File modelFile) throws IOException {
  19786. - return createFromFileAndOptions(modelFile, AudioClassifierOptions.builder().build());
  19787. - }
  19788. -
  19789. - /**
  19790. - * Creates an {@link AudioClassifier} instance with a model buffer and the default {@link
  19791. - * AudioClassifierOptions}.
  19792. - *
  19793. - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
  19794. - * classification model
  19795. - * @throws IllegalStateException if there is an internal error
  19796. - * @throws RuntimeException if there is an otherwise unspecified error
  19797. - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  19798. - * {@link MappedByteBuffer}
  19799. - */
  19800. - public static AudioClassifier createFromBuffer(final ByteBuffer modelBuffer) {
  19801. - return createFromBufferAndOptions(modelBuffer, AudioClassifierOptions.builder().build());
  19802. - }
  19803. -
  19804. - /**
  19805. - * Creates an {@link AudioClassifier} instance from {@link AudioClassifierOptions}.
  19806. - *
  19807. - * @param modelPath path of the classification model with metadata in the assets
  19808. - * @throws IOException if an I/O error occurs when loading the tflite model
  19809. - * @throws IllegalArgumentException if an argument is invalid
  19810. - * @throws IllegalStateException if there is an internal error
  19811. - * @throws RuntimeException if there is an otherwise unspecified error
  19812. - */
  19813. - public static AudioClassifier createFromFileAndOptions(
  19814. - Context context, String modelPath, AudioClassifierOptions options) throws IOException {
  19815. - return new AudioClassifier(
  19816. - TaskJniUtils.createHandleFromFdAndOptions(
  19817. - context,
  19818. - new FdAndOptionsHandleProvider<AudioClassifierOptions>() {
  19819. - @Override
  19820. - public long createHandle(
  19821. - int fileDescriptor,
  19822. - long fileDescriptorLength,
  19823. - long fileDescriptorOffset,
  19824. - AudioClassifierOptions options) {
  19825. - return initJniWithModelFdAndOptions(
  19826. - fileDescriptor,
  19827. - fileDescriptorLength,
  19828. - fileDescriptorOffset,
  19829. - options,
  19830. - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
  19831. - }
  19832. - },
  19833. - AUDIO_CLASSIFIER_NATIVE_LIB,
  19834. - modelPath,
  19835. - options));
  19836. - }
  19837. -
  19838. - /**
  19839. - * Creates an {@link AudioClassifier} instance.
  19840. - *
  19841. - * @param modelFile the classification model {@link File} instance
  19842. - * @throws IOException if an I/O error occurs when loading the tflite model
  19843. - * @throws IllegalArgumentException if an argument is invalid
  19844. - * @throws IllegalStateException if there is an internal error
  19845. - * @throws RuntimeException if there is an otherwise unspecified error
  19846. - */
  19847. - public static AudioClassifier createFromFileAndOptions(
  19848. - File modelFile, final AudioClassifierOptions options) throws IOException {
  19849. - try (ParcelFileDescriptor descriptor =
  19850. - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  19851. - return new AudioClassifier(
  19852. - TaskJniUtils.createHandleFromLibrary(
  19853. - new TaskJniUtils.EmptyHandleProvider() {
  19854. - @Override
  19855. - public long createHandle() {
  19856. - return initJniWithModelFdAndOptions(
  19857. - descriptor.getFd(),
  19858. - /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
  19859. - /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
  19860. - options,
  19861. - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
  19862. - }
  19863. - },
  19864. - AUDIO_CLASSIFIER_NATIVE_LIB));
  19865. + /**
  19866. + * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}.
  19867. + *
  19868. + * @param modelFile the classification model {@link File} instance
  19869. + * @throws IOException if an I/O error occurs when loading the tflite model
  19870. + * @throws IllegalArgumentException if an argument is invalid
  19871. + * @throws IllegalStateException if there is an internal error
  19872. + * @throws RuntimeException if there is an otherwise unspecified error
  19873. + */
  19874. + public static AudioClassifier createFromFile(File modelFile) throws IOException {
  19875. + return createFromFileAndOptions(modelFile, AudioClassifierOptions.builder().build());
  19876. }
  19877. - }
  19878. -
  19879. - /**
  19880. - * Creates an {@link AudioClassifier} instance with a model buffer and {@link
  19881. - * AudioClassifierOptions}.
  19882. - *
  19883. - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
  19884. - * classification model
  19885. - * @throws IllegalStateException if there is an internal error
  19886. - * @throws RuntimeException if there is an otherwise unspecified error
  19887. - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  19888. - * {@link MappedByteBuffer}
  19889. - */
  19890. - public static AudioClassifier createFromBufferAndOptions(
  19891. - final ByteBuffer modelBuffer, final AudioClassifierOptions options) {
  19892. - if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
  19893. - throw new IllegalArgumentException(
  19894. - "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
  19895. +
  19896. + /**
  19897. + * Creates an {@link AudioClassifier} instance with a model buffer and the default {@link
  19898. + * AudioClassifierOptions}.
  19899. + *
  19900. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
  19901. + * classification model
  19902. + * @throws IllegalStateException if there is an internal error
  19903. + * @throws RuntimeException if there is an otherwise unspecified error
  19904. + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  19905. + * {@link MappedByteBuffer}
  19906. + */
  19907. + public static AudioClassifier createFromBuffer(final ByteBuffer modelBuffer) {
  19908. + return createFromBufferAndOptions(modelBuffer, AudioClassifierOptions.builder().build());
  19909. }
  19910. - return new AudioClassifier(
  19911. - TaskJniUtils.createHandleFromLibrary(
  19912. - new EmptyHandleProvider() {
  19913. - @Override
  19914. - public long createHandle() {
  19915. - return initJniWithByteBuffer(
  19916. - modelBuffer,
  19917. - options,
  19918. - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
  19919. - }
  19920. - },
  19921. - AUDIO_CLASSIFIER_NATIVE_LIB));
  19922. - }
  19923. -
  19924. - /**
  19925. - * Constructor to initialize the JNI with a pointer from C++.
  19926. - *
  19927. - * @param nativeHandle a pointer referencing memory allocated in C++
  19928. - */
  19929. - private AudioClassifier(long nativeHandle) {
  19930. - super(nativeHandle);
  19931. - }
  19932. -
  19933. - /** Options for setting up an {@link AudioClassifier}. */
  19934. - @UsedByReflection("audio_classifier_jni.cc")
  19935. - public static class AudioClassifierOptions {
  19936. - // Not using AutoValue for this class because scoreThreshold cannot have default value
  19937. - // (otherwise, the default value would override the one in the model metadata) and `Optional` is
  19938. - // not an option here, because
  19939. - // 1. java.util.Optional require Java 8 while we need to support Java 7.
  19940. - // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the
  19941. - // comments for labelAllowList.
  19942. - private final BaseOptions baseOptions;
  19943. - private final String displayNamesLocale;
  19944. - private final int maxResults;
  19945. - private final float scoreThreshold;
  19946. - private final boolean isScoreThresholdSet;
  19947. - // As an open source project, we've been trying avoiding depending on common java libraries,
  19948. - // such as Guava, because it may introduce conflicts with clients who also happen to use those
  19949. - // libraries. Therefore, instead of using ImmutableList here, we convert the List into
  19950. - // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
  19951. - // vulnerable.
  19952. - private final List<String> labelAllowList;
  19953. - private final List<String> labelDenyList;
  19954. -
  19955. - public static Builder builder() {
  19956. - return new Builder();
  19957. +
  19958. + /**
  19959. + * Creates an {@link AudioClassifier} instance from {@link AudioClassifierOptions}.
  19960. + *
  19961. + * @param modelPath path of the classification model with metadata in the assets
  19962. + * @throws IOException if an I/O error occurs when loading the tflite model
  19963. + * @throws IllegalArgumentException if an argument is invalid
  19964. + * @throws IllegalStateException if there is an internal error
  19965. + * @throws RuntimeException if there is an otherwise unspecified error
  19966. + */
  19967. + public static AudioClassifier createFromFileAndOptions(
  19968. + Context context, String modelPath, AudioClassifierOptions options) throws IOException {
  19969. + return new AudioClassifier(TaskJniUtils.createHandleFromFdAndOptions(
  19970. + context, new FdAndOptionsHandleProvider<AudioClassifierOptions>() {
  19971. + @Override
  19972. + public long createHandle(int fileDescriptor, long fileDescriptorLength,
  19973. + long fileDescriptorOffset, AudioClassifierOptions options) {
  19974. + return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength,
  19975. + fileDescriptorOffset, options,
  19976. + TaskJniUtils.createProtoBaseOptionsHandle(
  19977. + options.getBaseOptions()));
  19978. + }
  19979. + }, AUDIO_CLASSIFIER_NATIVE_LIB, modelPath, options));
  19980. }
  19981. - /** A builder that helps to configure an instance of AudioClassifierOptions. */
  19982. - public static class Builder {
  19983. - private BaseOptions baseOptions = BaseOptions.builder().build();
  19984. - private String displayNamesLocale = "en";
  19985. - private int maxResults = -1;
  19986. - private float scoreThreshold;
  19987. - private boolean isScoreThresholdSet;
  19988. - private List<String> labelAllowList = new ArrayList<>();
  19989. - private List<String> labelDenyList = new ArrayList<>();
  19990. -
  19991. - private Builder() {}
  19992. -
  19993. - /** Sets the general options to configure Task APIs, such as accelerators. */
  19994. - public Builder setBaseOptions(BaseOptions baseOptions) {
  19995. - this.baseOptions = baseOptions;
  19996. - return this;
  19997. - }
  19998. -
  19999. - /**
  20000. - * Sets the locale to use for display names specified through the TFLite Model Metadata, if
  20001. - * any.
  20002. - *
  20003. - * <p>Defaults to English({@code "en"}). See the <a
  20004. - * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
  20005. - * Metadata schema file.</a> for the accepted pattern of locale.
  20006. - */
  20007. - public Builder setDisplayNamesLocale(String displayNamesLocale) {
  20008. - this.displayNamesLocale = displayNamesLocale;
  20009. - return this;
  20010. - }
  20011. -
  20012. - /**
  20013. - * Sets the maximum number of top scored results to return.
  20014. - *
  20015. - * @param maxResults if < 0, all results will be returned. If 0, an invalid argument error is
  20016. - * returned. Defaults to -1.
  20017. - * @throws IllegalArgumentException if maxResults is 0
  20018. - */
  20019. - public Builder setMaxResults(int maxResults) {
  20020. - if (maxResults == 0) {
  20021. - throw new IllegalArgumentException("maxResults cannot be 0.");
  20022. + /**
  20023. + * Creates an {@link AudioClassifier} instance.
  20024. + *
  20025. + * @param modelFile the classification model {@link File} instance
  20026. + * @throws IOException if an I/O error occurs when loading the tflite model
  20027. + * @throws IllegalArgumentException if an argument is invalid
  20028. + * @throws IllegalStateException if there is an internal error
  20029. + * @throws RuntimeException if there is an otherwise unspecified error
  20030. + */
  20031. + public static AudioClassifier createFromFileAndOptions(
  20032. + File modelFile, final AudioClassifierOptions options) throws IOException {
  20033. + try (ParcelFileDescriptor descriptor =
  20034. + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  20035. + return new AudioClassifier(
  20036. + TaskJniUtils.createHandleFromLibrary(new TaskJniUtils.EmptyHandleProvider() {
  20037. + @Override
  20038. + public long createHandle() {
  20039. + return initJniWithModelFdAndOptions(descriptor.getFd(),
  20040. + /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH,
  20041. + /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options,
  20042. + TaskJniUtils.createProtoBaseOptionsHandle(
  20043. + options.getBaseOptions()));
  20044. + }
  20045. + }, AUDIO_CLASSIFIER_NATIVE_LIB));
  20046. }
  20047. - this.maxResults = maxResults;
  20048. - return this;
  20049. - }
  20050. -
  20051. - /**
  20052. - * Sets the score threshold.
  20053. - *
  20054. - * <p>It overrides the one provided in the model metadata (if any). Results below this value
  20055. - * are rejected.
  20056. - */
  20057. - public Builder setScoreThreshold(float scoreThreshold) {
  20058. - this.scoreThreshold = scoreThreshold;
  20059. - isScoreThresholdSet = true;
  20060. - return this;
  20061. - }
  20062. -
  20063. - /**
  20064. - * Sets the optional allowlist of labels.
  20065. - *
  20066. - * <p>If non-empty, classifications whose label is not in this set will be filtered out.
  20067. - * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList.
  20068. - */
  20069. - public Builder setLabelAllowList(List<String> labelAllowList) {
  20070. - this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
  20071. - return this;
  20072. - }
  20073. -
  20074. - /**
  20075. - * Sets the optional denylist of labels.
  20076. - *
  20077. - * <p>If non-empty, classifications whose label is in this set will be filtered out. Duplicate
  20078. - * or unknown labels are ignored. Mutually exclusive with labelAllowList.
  20079. - */
  20080. - public Builder setLabelDenyList(List<String> labelDenyList) {
  20081. - this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
  20082. - return this;
  20083. - }
  20084. -
  20085. - public AudioClassifierOptions build() {
  20086. - return new AudioClassifierOptions(this);
  20087. - }
  20088. }
  20089. - @UsedByReflection("audio_classifier_jni.cc")
  20090. - public String getDisplayNamesLocale() {
  20091. - return displayNamesLocale;
  20092. + /**
  20093. + * Creates an {@link AudioClassifier} instance with a model buffer and {@link
  20094. + * AudioClassifierOptions}.
  20095. + *
  20096. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
  20097. + * classification model
  20098. + * @throws IllegalStateException if there is an internal error
  20099. + * @throws RuntimeException if there is an otherwise unspecified error
  20100. + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  20101. + * {@link MappedByteBuffer}
  20102. + */
  20103. + public static AudioClassifier createFromBufferAndOptions(
  20104. + final ByteBuffer modelBuffer, final AudioClassifierOptions options) {
  20105. + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
  20106. + throw new IllegalArgumentException(
  20107. + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
  20108. + }
  20109. + return new AudioClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
  20110. + @Override
  20111. + public long createHandle() {
  20112. + return initJniWithByteBuffer(modelBuffer, options,
  20113. + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
  20114. + }
  20115. + }, AUDIO_CLASSIFIER_NATIVE_LIB));
  20116. }
  20117. - @UsedByReflection("audio_classifier_jni.cc")
  20118. - public int getMaxResults() {
  20119. - return maxResults;
  20120. + /**
  20121. + * Constructor to initialize the JNI with a pointer from C++.
  20122. + *
  20123. + * @param nativeHandle a pointer referencing memory allocated in C++
  20124. + */
  20125. + private AudioClassifier(long nativeHandle) {
  20126. + super(nativeHandle);
  20127. }
  20128. + /** Options for setting up an {@link AudioClassifier}. */
  20129. @UsedByReflection("audio_classifier_jni.cc")
  20130. - public float getScoreThreshold() {
  20131. - return scoreThreshold;
  20132. + public static class AudioClassifierOptions {
  20133. + // Not using AutoValue for this class because scoreThreshold cannot have default value
  20134. + // (otherwise, the default value would override the one in the model metadata) and
  20135. + // `Optional` is not an option here, because
  20136. + // 1. java.util.Optional require Java 8 while we need to support Java 7.
  20137. + // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See
  20138. + // the comments for labelAllowList.
  20139. + private final BaseOptions baseOptions;
  20140. + private final String displayNamesLocale;
  20141. + private final int maxResults;
  20142. + private final float scoreThreshold;
  20143. + private final boolean isScoreThresholdSet;
  20144. + // As an open source project, we've been trying avoiding depending on common java libraries,
  20145. + // such as Guava, because it may introduce conflicts with clients who also happen to use
  20146. + // those libraries. Therefore, instead of using ImmutableList here, we convert the List into
  20147. + // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
  20148. + // vulnerable.
  20149. + private final List<String> labelAllowList;
  20150. + private final List<String> labelDenyList;
  20151. +
  20152. + public static Builder builder() {
  20153. + return new Builder();
  20154. + }
  20155. +
  20156. + /** A builder that helps to configure an instance of AudioClassifierOptions. */
  20157. + public static class Builder {
  20158. + private BaseOptions baseOptions = BaseOptions.builder().build();
  20159. + private String displayNamesLocale = "en";
  20160. + private int maxResults = -1;
  20161. + private float scoreThreshold;
  20162. + private boolean isScoreThresholdSet;
  20163. + private List<String> labelAllowList = new ArrayList<>();
  20164. + private List<String> labelDenyList = new ArrayList<>();
  20165. +
  20166. + private Builder() {}
  20167. +
  20168. + /** Sets the general options to configure Task APIs, such as accelerators. */
  20169. + public Builder setBaseOptions(BaseOptions baseOptions) {
  20170. + this.baseOptions = baseOptions;
  20171. + return this;
  20172. + }
  20173. +
  20174. + /**
  20175. + * Sets the locale to use for display names specified through the TFLite Model Metadata,
  20176. + * if any.
  20177. + *
  20178. + * <p>Defaults to English({@code "en"}). See the <a
  20179. + * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
  20180. + * Metadata schema file.</a> for the accepted pattern of locale.
  20181. + */
  20182. + public Builder setDisplayNamesLocale(String displayNamesLocale) {
  20183. + this.displayNamesLocale = displayNamesLocale;
  20184. + return this;
  20185. + }
  20186. +
  20187. + /**
  20188. + * Sets the maximum number of top scored results to return.
  20189. + *
  20190. + * @param maxResults if < 0, all results will be returned. If 0, an invalid argument
  20191. + * error is
  20192. + * returned. Defaults to -1.
  20193. + * @throws IllegalArgumentException if maxResults is 0
  20194. + */
  20195. + public Builder setMaxResults(int maxResults) {
  20196. + if (maxResults == 0) {
  20197. + throw new IllegalArgumentException("maxResults cannot be 0.");
  20198. + }
  20199. + this.maxResults = maxResults;
  20200. + return this;
  20201. + }
  20202. +
  20203. + /**
  20204. + * Sets the score threshold.
  20205. + *
  20206. + * <p>It overrides the one provided in the model metadata (if any). Results below this
  20207. + * value are rejected.
  20208. + */
  20209. + public Builder setScoreThreshold(float scoreThreshold) {
  20210. + this.scoreThreshold = scoreThreshold;
  20211. + isScoreThresholdSet = true;
  20212. + return this;
  20213. + }
  20214. +
  20215. + /**
  20216. + * Sets the optional allowlist of labels.
  20217. + *
  20218. + * <p>If non-empty, classifications whose label is not in this set will be filtered out.
  20219. + * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList.
  20220. + */
  20221. + public Builder setLabelAllowList(List<String> labelAllowList) {
  20222. + this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
  20223. + return this;
  20224. + }
  20225. +
  20226. + /**
  20227. + * Sets the optional denylist of labels.
  20228. + *
  20229. + * <p>If non-empty, classifications whose label is in this set will be filtered out.
  20230. + * Duplicate or unknown labels are ignored. Mutually exclusive with labelAllowList.
  20231. + */
  20232. + public Builder setLabelDenyList(List<String> labelDenyList) {
  20233. + this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
  20234. + return this;
  20235. + }
  20236. +
  20237. + public AudioClassifierOptions build() {
  20238. + return new AudioClassifierOptions(this);
  20239. + }
  20240. + }
  20241. +
  20242. + @UsedByReflection("audio_classifier_jni.cc")
  20243. + public String getDisplayNamesLocale() {
  20244. + return displayNamesLocale;
  20245. + }
  20246. +
  20247. + @UsedByReflection("audio_classifier_jni.cc")
  20248. + public int getMaxResults() {
  20249. + return maxResults;
  20250. + }
  20251. +
  20252. + @UsedByReflection("audio_classifier_jni.cc")
  20253. + public float getScoreThreshold() {
  20254. + return scoreThreshold;
  20255. + }
  20256. +
  20257. + @UsedByReflection("audio_classifier_jni.cc")
  20258. + public boolean getIsScoreThresholdSet() {
  20259. + return isScoreThresholdSet;
  20260. + }
  20261. +
  20262. + @UsedByReflection("audio_classifier_jni.cc")
  20263. + public List<String> getLabelAllowList() {
  20264. + return new ArrayList<>(labelAllowList);
  20265. + }
  20266. +
  20267. + @UsedByReflection("audio_classifier_jni.cc")
  20268. + public List<String> getLabelDenyList() {
  20269. + return new ArrayList<>(labelDenyList);
  20270. + }
  20271. +
  20272. + public BaseOptions getBaseOptions() {
  20273. + return baseOptions;
  20274. + }
  20275. +
  20276. + private AudioClassifierOptions(Builder builder) {
  20277. + displayNamesLocale = builder.displayNamesLocale;
  20278. + maxResults = builder.maxResults;
  20279. + scoreThreshold = builder.scoreThreshold;
  20280. + isScoreThresholdSet = builder.isScoreThresholdSet;
  20281. + labelAllowList = builder.labelAllowList;
  20282. + labelDenyList = builder.labelDenyList;
  20283. + baseOptions = builder.baseOptions;
  20284. + }
  20285. }
  20286. - @UsedByReflection("audio_classifier_jni.cc")
  20287. - public boolean getIsScoreThresholdSet() {
  20288. - return isScoreThresholdSet;
  20289. + /**
  20290. + * Performs actual classification on the provided audio tensor.
  20291. + *
  20292. + * @param tensor a {@link TensorAudio} containing the input audio clip in float with values
  20293. + * between [-1, 1). The {@code tensor} argument should have the same flat size as the TFLite
  20294. + * model's input tensor. It's recommended to create {@code tensor} using {@code
  20295. + * createInputTensorAudio} method.
  20296. + * @throws IllegalArgumentException if an argument is invalid
  20297. + * @throws IllegalStateException if error occurs when classifying the audio clip from the native
  20298. + * code
  20299. + */
  20300. + public List<Classifications> classify(TensorAudio tensor) {
  20301. + TensorBuffer buffer = tensor.getTensorBuffer();
  20302. + TensorAudioFormat format = tensor.getFormat();
  20303. + checkState(buffer.getBuffer().hasArray(),
  20304. + "Input tensor buffer should be a non-direct buffer with a backed array (i.e. not readonly"
  20305. + + " buffer).");
  20306. + return classifyNative(getNativeHandle(), buffer.getBuffer().array(), format.getChannels(),
  20307. + format.getSampleRate());
  20308. }
  20309. - @UsedByReflection("audio_classifier_jni.cc")
  20310. - public List<String> getLabelAllowList() {
  20311. - return new ArrayList<>(labelAllowList);
  20312. + /**
  20313. + * Creates a {@link TensorAudio} instance to store input audio samples.
  20314. + *
  20315. + * @return a {@link TensorAudio} with the same size as model input tensor
  20316. + * @throws IllegalArgumentException if the model is not compatible
  20317. + */
  20318. + public TensorAudio createInputTensorAudio() {
  20319. + TensorAudioFormat format = getRequiredTensorAudioFormat();
  20320. +
  20321. + long bufferSize = getRequiredInputBufferSize();
  20322. + long samples = bufferSize / format.getChannels();
  20323. + return TensorAudio.create(format, (int) samples);
  20324. }
  20325. - @UsedByReflection("audio_classifier_jni.cc")
  20326. - public List<String> getLabelDenyList() {
  20327. - return new ArrayList<>(labelDenyList);
  20328. + /** Returns the required input buffer size in number of float elements. */
  20329. + public long getRequiredInputBufferSize() {
  20330. + return getRequiredInputBufferSizeNative(getNativeHandle());
  20331. }
  20332. - public BaseOptions getBaseOptions() {
  20333. - return baseOptions;
  20334. + /**
  20335. + * Creates an {@link android.media.AudioRecord} instance to record audio stream. The returned
  20336. + * AudioRecord instance is initialized and client needs to call {@link
  20337. + * android.media.AudioRecord#startRecording} method to start recording.
  20338. + *
  20339. + * @return an {@link android.media.AudioRecord} instance in {@link
  20340. + * android.media.AudioRecord#STATE_INITIALIZED}
  20341. + * @throws IllegalArgumentException if the model required channel count is unsupported
  20342. + * @throws IllegalStateException if AudioRecord instance failed to initialize
  20343. + */
  20344. + public AudioRecord createAudioRecord() {
  20345. + TensorAudioFormat format = getRequiredTensorAudioFormat();
  20346. + int channelConfig = 0;
  20347. +
  20348. + switch (format.getChannels()) {
  20349. + case 1:
  20350. + channelConfig = AudioFormat.CHANNEL_IN_MONO;
  20351. + break;
  20352. + case 2:
  20353. + channelConfig = AudioFormat.CHANNEL_IN_STEREO;
  20354. + break;
  20355. + default:
  20356. + throw new IllegalArgumentException(String.format(
  20357. + "Number of channels required by the model is %d. getAudioRecord method only"
  20358. + + " supports 1 or 2 audio channels.",
  20359. + format.getChannels()));
  20360. + }
  20361. +
  20362. + int bufferSizeInBytes = AudioRecord.getMinBufferSize(
  20363. + format.getSampleRate(), channelConfig, AudioFormat.ENCODING_PCM_FLOAT);
  20364. + if (bufferSizeInBytes == AudioRecord.ERROR
  20365. + || bufferSizeInBytes == AudioRecord.ERROR_BAD_VALUE) {
  20366. + throw new IllegalStateException(String.format(
  20367. + "AudioRecord.getMinBufferSize failed. Returned: %d", bufferSizeInBytes));
  20368. + }
  20369. + // The buffer of AudioRecord should be strictly longer than what model requires so that
  20370. + // clients could run `TensorAudio::load(record)` together with `AudioClassifier::classify`.
  20371. + int bufferSizeMultiplier = 2;
  20372. + int modelRequiredBufferSize = (int) getRequiredInputBufferSize()
  20373. + * DataType.FLOAT32.byteSize() * bufferSizeMultiplier;
  20374. + if (bufferSizeInBytes < modelRequiredBufferSize) {
  20375. + bufferSizeInBytes = modelRequiredBufferSize;
  20376. + }
  20377. + AudioRecord audioRecord = new AudioRecord(
  20378. + // including MIC, UNPROCESSED, and CAMCORDER.
  20379. + MediaRecorder.AudioSource.VOICE_RECOGNITION, format.getSampleRate(), channelConfig,
  20380. + AudioFormat.ENCODING_PCM_FLOAT, bufferSizeInBytes);
  20381. + checkState(audioRecord.getState() == AudioRecord.STATE_INITIALIZED,
  20382. + "AudioRecord failed to initialize");
  20383. + return audioRecord;
  20384. }
  20385. - private AudioClassifierOptions(Builder builder) {
  20386. - displayNamesLocale = builder.displayNamesLocale;
  20387. - maxResults = builder.maxResults;
  20388. - scoreThreshold = builder.scoreThreshold;
  20389. - isScoreThresholdSet = builder.isScoreThresholdSet;
  20390. - labelAllowList = builder.labelAllowList;
  20391. - labelDenyList = builder.labelDenyList;
  20392. - baseOptions = builder.baseOptions;
  20393. + /** Returns the {@link TensorAudioFormat} required by the model. */
  20394. + public TensorAudioFormat getRequiredTensorAudioFormat() {
  20395. + return TensorAudioFormat.builder()
  20396. + .setChannels(getRequiredChannels())
  20397. + .setSampleRate(getRequiredSampleRate())
  20398. + .build();
  20399. }
  20400. - }
  20401. -
  20402. - /**
  20403. - * Performs actual classification on the provided audio tensor.
  20404. - *
  20405. - * @param tensor a {@link TensorAudio} containing the input audio clip in float with values
  20406. - * between [-1, 1). The {@code tensor} argument should have the same flat size as the TFLite
  20407. - * model's input tensor. It's recommended to create {@code tensor} using {@code
  20408. - * createInputTensorAudio} method.
  20409. - * @throws IllegalArgumentException if an argument is invalid
  20410. - * @throws IllegalStateException if error occurs when classifying the audio clip from the native
  20411. - * code
  20412. - */
  20413. - public List<Classifications> classify(TensorAudio tensor) {
  20414. - TensorBuffer buffer = tensor.getTensorBuffer();
  20415. - TensorAudioFormat format = tensor.getFormat();
  20416. - checkState(
  20417. - buffer.getBuffer().hasArray(),
  20418. - "Input tensor buffer should be a non-direct buffer with a backed array (i.e. not readonly"
  20419. - + " buffer).");
  20420. - return classifyNative(
  20421. - getNativeHandle(),
  20422. - buffer.getBuffer().array(),
  20423. - format.getChannels(),
  20424. - format.getSampleRate());
  20425. - }
  20426. -
  20427. - /**
  20428. - * Creates a {@link TensorAudio} instance to store input audio samples.
  20429. - *
  20430. - * @return a {@link TensorAudio} with the same size as model input tensor
  20431. - * @throws IllegalArgumentException if the model is not compatible
  20432. - */
  20433. - public TensorAudio createInputTensorAudio() {
  20434. - TensorAudioFormat format = getRequiredTensorAudioFormat();
  20435. -
  20436. - long bufferSize = getRequiredInputBufferSize();
  20437. - long samples = bufferSize / format.getChannels();
  20438. - return TensorAudio.create(format, (int) samples);
  20439. - }
  20440. -
  20441. - /** Returns the required input buffer size in number of float elements. */
  20442. - public long getRequiredInputBufferSize() {
  20443. - return getRequiredInputBufferSizeNative(getNativeHandle());
  20444. - }
  20445. -
  20446. - /**
  20447. - * Creates an {@link android.media.AudioRecord} instance to record audio stream. The returned
  20448. - * AudioRecord instance is initialized and client needs to call {@link
  20449. - * android.media.AudioRecord#startRecording} method to start recording.
  20450. - *
  20451. - * @return an {@link android.media.AudioRecord} instance in {@link
  20452. - * android.media.AudioRecord#STATE_INITIALIZED}
  20453. - * @throws IllegalArgumentException if the model required channel count is unsupported
  20454. - * @throws IllegalStateException if AudioRecord instance failed to initialize
  20455. - */
  20456. - public AudioRecord createAudioRecord() {
  20457. - TensorAudioFormat format = getRequiredTensorAudioFormat();
  20458. - int channelConfig = 0;
  20459. -
  20460. - switch (format.getChannels()) {
  20461. - case 1:
  20462. - channelConfig = AudioFormat.CHANNEL_IN_MONO;
  20463. - break;
  20464. - case 2:
  20465. - channelConfig = AudioFormat.CHANNEL_IN_STEREO;
  20466. - break;
  20467. - default:
  20468. - throw new IllegalArgumentException(
  20469. - String.format(
  20470. - "Number of channels required by the model is %d. getAudioRecord method only"
  20471. - + " supports 1 or 2 audio channels.",
  20472. - format.getChannels()));
  20473. +
  20474. + private int getRequiredChannels() {
  20475. + return getRequiredChannelsNative(getNativeHandle());
  20476. }
  20477. - int bufferSizeInBytes =
  20478. - AudioRecord.getMinBufferSize(
  20479. - format.getSampleRate(), channelConfig, AudioFormat.ENCODING_PCM_FLOAT);
  20480. - if (bufferSizeInBytes == AudioRecord.ERROR
  20481. - || bufferSizeInBytes == AudioRecord.ERROR_BAD_VALUE) {
  20482. - throw new IllegalStateException(
  20483. - String.format("AudioRecord.getMinBufferSize failed. Returned: %d", bufferSizeInBytes));
  20484. + private int getRequiredSampleRate() {
  20485. + return getRequiredSampleRateNative(getNativeHandle());
  20486. }
  20487. - // The buffer of AudioRecord should be strictly longer than what model requires so that clients
  20488. - // could run `TensorAudio::load(record)` together with `AudioClassifier::classify`.
  20489. - int bufferSizeMultiplier = 2;
  20490. - int modelRequiredBufferSize =
  20491. - (int) getRequiredInputBufferSize() * DataType.FLOAT32.byteSize() * bufferSizeMultiplier;
  20492. - if (bufferSizeInBytes < modelRequiredBufferSize) {
  20493. - bufferSizeInBytes = modelRequiredBufferSize;
  20494. +
  20495. + // TODO(b/183343074): JNI method invocation is very expensive, taking about .2ms
  20496. + // each time. Consider combining the native getter methods into 1 and cache it in Java layer.
  20497. + private static native long getRequiredInputBufferSizeNative(long nativeHandle);
  20498. +
  20499. + private static native int getRequiredChannelsNative(long nativeHandle);
  20500. +
  20501. + private static native int getRequiredSampleRateNative(long nativeHandle);
  20502. +
  20503. + private static native List<Classifications> classifyNative(
  20504. + long nativeHandle, byte[] audioBuffer, int channels, int sampleRate);
  20505. +
  20506. + private static native long initJniWithModelFdAndOptions(int fileDescriptor,
  20507. + long fileDescriptorLength, long fileDescriptorOffset, AudioClassifierOptions options,
  20508. + long baseOptionsHandle);
  20509. +
  20510. + private static native long initJniWithByteBuffer(
  20511. + ByteBuffer modelBuffer, AudioClassifierOptions options, long baseOptionsHandle);
  20512. +
  20513. + /**
  20514. + * Releases memory pointed by {@code nativeHandle}, namely a C++ `AudioClassifier` instance.
  20515. + *
  20516. + * @param nativeHandle pointer to memory allocated
  20517. + */
  20518. + @Override
  20519. + protected void deinit(long nativeHandle) {
  20520. + deinitJni(nativeHandle);
  20521. }
  20522. - AudioRecord audioRecord =
  20523. - new AudioRecord(
  20524. - // including MIC, UNPROCESSED, and CAMCORDER.
  20525. - MediaRecorder.AudioSource.VOICE_RECOGNITION,
  20526. - format.getSampleRate(),
  20527. - channelConfig,
  20528. - AudioFormat.ENCODING_PCM_FLOAT,
  20529. - bufferSizeInBytes);
  20530. - checkState(
  20531. - audioRecord.getState() == AudioRecord.STATE_INITIALIZED,
  20532. - "AudioRecord failed to initialize");
  20533. - return audioRecord;
  20534. - }
  20535. -
  20536. - /** Returns the {@link TensorAudioFormat} required by the model. */
  20537. - public TensorAudioFormat getRequiredTensorAudioFormat() {
  20538. - return TensorAudioFormat.builder()
  20539. - .setChannels(getRequiredChannels())
  20540. - .setSampleRate(getRequiredSampleRate())
  20541. - .build();
  20542. - }
  20543. -
  20544. - private int getRequiredChannels() {
  20545. - return getRequiredChannelsNative(getNativeHandle());
  20546. - }
  20547. -
  20548. - private int getRequiredSampleRate() {
  20549. - return getRequiredSampleRateNative(getNativeHandle());
  20550. - }
  20551. -
  20552. - // TODO(b/183343074): JNI method invocation is very expensive, taking about .2ms
  20553. - // each time. Consider combining the native getter methods into 1 and cache it in Java layer.
  20554. - private static native long getRequiredInputBufferSizeNative(long nativeHandle);
  20555. -
  20556. - private static native int getRequiredChannelsNative(long nativeHandle);
  20557. -
  20558. - private static native int getRequiredSampleRateNative(long nativeHandle);
  20559. -
  20560. - private static native List<Classifications> classifyNative(
  20561. - long nativeHandle, byte[] audioBuffer, int channels, int sampleRate);
  20562. -
  20563. - private static native long initJniWithModelFdAndOptions(
  20564. - int fileDescriptor,
  20565. - long fileDescriptorLength,
  20566. - long fileDescriptorOffset,
  20567. - AudioClassifierOptions options,
  20568. - long baseOptionsHandle);
  20569. -
  20570. - private static native long initJniWithByteBuffer(
  20571. - ByteBuffer modelBuffer, AudioClassifierOptions options, long baseOptionsHandle);
  20572. -
  20573. - /**
  20574. - * Releases memory pointed by {@code nativeHandle}, namely a C++ `AudioClassifier` instance.
  20575. - *
  20576. - * @param nativeHandle pointer to memory allocated
  20577. - */
  20578. - @Override
  20579. - protected void deinit(long nativeHandle) {
  20580. - deinitJni(nativeHandle);
  20581. - }
  20582. -
  20583. - /**
  20584. - * Native method to release memory pointed by {@code nativeHandle}, namely a C++ `AudioClassifier`
  20585. - * instance.
  20586. - *
  20587. - * @param nativeHandle pointer to memory allocated
  20588. - */
  20589. - private static native void deinitJni(long nativeHandle);
  20590. +
  20591. + /**
  20592. + * Native method to release memory pointed by {@code nativeHandle}, namely a C++
  20593. + * `AudioClassifier` instance.
  20594. + *
  20595. + * @param nativeHandle pointer to memory allocated
  20596. + */
  20597. + private static native void deinitJni(long nativeHandle);
  20598. }
  20599. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java
  20600. index 9c0cdf9e249ae..8e8270269dad8 100644
  20601. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java
  20602. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java
  20603. @@ -16,11 +16,13 @@ limitations under the License.
  20604. package org.tensorflow.lite.task.audio.classifier;
  20605. import com.google.auto.value.AutoValue;
  20606. +
  20607. +import org.tensorflow.lite.support.label.Category;
  20608. +import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  20609. +
  20610. import java.util.ArrayList;
  20611. import java.util.Collections;
  20612. import java.util.List;
  20613. -import org.tensorflow.lite.support.label.Category;
  20614. -import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  20615. /**
  20616. * The classification results of one head in a multihead (a.k.a. multi-output) {@link
  20617. @@ -31,18 +33,18 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  20618. @AutoValue
  20619. @UsedByReflection("audio_classifier_jni.cc")
  20620. public abstract class Classifications {
  20621. + @UsedByReflection("audio_classifier_jni.cc")
  20622. + static Classifications create(List<Category> categories, int headIndex, String headName) {
  20623. + return new AutoValue_Classifications(
  20624. + Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex,
  20625. + headName);
  20626. + }
  20627. - @UsedByReflection("audio_classifier_jni.cc")
  20628. - static Classifications create(List<Category> categories, int headIndex, String headName) {
  20629. - return new AutoValue_Classifications(
  20630. - Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex, headName);
  20631. - }
  20632. -
  20633. - // Same reason for not using ImmutableList as stated in
  20634. - // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}.
  20635. - public abstract List<Category> getCategories();
  20636. + // Same reason for not using ImmutableList as stated in
  20637. + // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}.
  20638. + public abstract List<Category> getCategories();
  20639. - public abstract int getHeadIndex();
  20640. + public abstract int getHeadIndex();
  20641. - public abstract String getHeadName();
  20642. + public abstract String getHeadName();
  20643. }
  20644. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java
  20645. index 242414bd21bdb..b2d722332c954 100644
  20646. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java
  20647. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java
  20648. @@ -20,65 +20,66 @@ import com.google.auto.value.AutoValue;
  20649. /** Options to configure Task APIs in general. */
  20650. @AutoValue
  20651. public abstract class BaseOptions {
  20652. - private static final int DEFAULT_NUM_THREADS = -1;
  20653. + private static final int DEFAULT_NUM_THREADS = -1;
  20654. - /** Builder for {@link BaseOptions}. */
  20655. - @AutoValue.Builder
  20656. - public abstract static class Builder {
  20657. + /** Builder for {@link BaseOptions}. */
  20658. + @AutoValue.Builder
  20659. + public abstract static class Builder {
  20660. + /**
  20661. + * Sets the advanced accelerator options.
  20662. + *
  20663. + * <p>Note: this method will override those highlevel API to choose an delegate, such as
  20664. + * {@link #useGpu} and {@link #useNnapi}.
  20665. + */
  20666. + public abstract Builder setComputeSettings(ComputeSettings computeSettings);
  20667. - /**
  20668. - * Sets the advanced accelerator options.
  20669. - *
  20670. - * <p>Note: this method will override those highlevel API to choose an delegate, such as {@link
  20671. - * #useGpu} and {@link #useNnapi}.
  20672. - */
  20673. - public abstract Builder setComputeSettings(ComputeSettings computeSettings);
  20674. + /**
  20675. + * Sets the number of threads to be used for TFLite ops that support multi-threading when
  20676. + * running inference with CPU. Defaults to -1.
  20677. + *
  20678. + * <p>{@code numThreads} should be greater than 0 or equal to -1. Setting numThreads to -1
  20679. + * has the effect to let TFLite runtime set the value.
  20680. + */
  20681. + public abstract Builder setNumThreads(int numThreads);
  20682. - /**
  20683. - * Sets the number of threads to be used for TFLite ops that support multi-threading when
  20684. - * running inference with CPU. Defaults to -1.
  20685. - *
  20686. - * <p>{@code numThreads} should be greater than 0 or equal to -1. Setting numThreads to -1 has
  20687. - * the effect to let TFLite runtime set the value.
  20688. - */
  20689. - public abstract Builder setNumThreads(int numThreads);
  20690. + /**
  20691. + * Uses GPU for inference. The advanced GPU configuration settings will be set to default
  20692. + * values.
  20693. + *
  20694. + * <p>Note: this method will override the settings from {@link #setComputeSettings}.
  20695. + *
  20696. + * <p>To manipulate the advanced GPU configuration settings, use {@link
  20697. + * #setComputeSettings}.
  20698. + */
  20699. + public Builder useGpu() {
  20700. + return setComputeSettings(
  20701. + ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.GPU).build());
  20702. + }
  20703. - /**
  20704. - * Uses GPU for inference. The advanced GPU configuration settings will be set to default
  20705. - * values.
  20706. - *
  20707. - * <p>Note: this method will override the settings from {@link #setComputeSettings}.
  20708. - *
  20709. - * <p>To manipulate the advanced GPU configuration settings, use {@link #setComputeSettings}.
  20710. - */
  20711. - public Builder useGpu() {
  20712. - return setComputeSettings(
  20713. - ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.GPU).build());
  20714. - }
  20715. + /**
  20716. + * Uses NNAPI for inference. The advanced NNAPI configuration settings will be set to
  20717. + * default values.
  20718. + *
  20719. + * <p>Note: this method will override the settings from {@link #setComputeSettings}.
  20720. + *
  20721. + * <p>To manipulate the advanced NNAPI configuration settings, use {@link
  20722. + * #setComputeSettings}.
  20723. + */
  20724. + public Builder useNnapi() {
  20725. + return setComputeSettings(
  20726. + ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.NNAPI).build());
  20727. + }
  20728. - /**
  20729. - * Uses NNAPI for inference. The advanced NNAPI configuration settings will be set to default
  20730. - * values.
  20731. - *
  20732. - * <p>Note: this method will override the settings from {@link #setComputeSettings}.
  20733. - *
  20734. - * <p>To manipulate the advanced NNAPI configuration settings, use {@link #setComputeSettings}.
  20735. - */
  20736. - public Builder useNnapi() {
  20737. - return setComputeSettings(
  20738. - ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.NNAPI).build());
  20739. + public abstract BaseOptions build();
  20740. }
  20741. - public abstract BaseOptions build();
  20742. - }
  20743. -
  20744. - public static Builder builder() {
  20745. - return new AutoValue_BaseOptions.Builder()
  20746. - .setComputeSettings(ComputeSettings.builder().build())
  20747. - .setNumThreads(DEFAULT_NUM_THREADS);
  20748. - }
  20749. + public static Builder builder() {
  20750. + return new AutoValue_BaseOptions.Builder()
  20751. + .setComputeSettings(ComputeSettings.builder().build())
  20752. + .setNumThreads(DEFAULT_NUM_THREADS);
  20753. + }
  20754. - abstract ComputeSettings getComputeSettings();
  20755. + abstract ComputeSettings getComputeSettings();
  20756. - abstract int getNumThreads();
  20757. + abstract int getNumThreads();
  20758. }
  20759. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java
  20760. index b3fe9def83c69..a8ae65cd1cf3b 100644
  20761. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java
  20762. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java
  20763. @@ -16,76 +16,78 @@ limitations under the License.
  20764. package org.tensorflow.lite.task.core;
  20765. import android.util.Log;
  20766. +
  20767. import java.io.Closeable;
  20768. /**
  20769. * Base class for Task API, provides shared logic to load/unload native libs to its C++ counterpart.
  20770. */
  20771. public abstract class BaseTaskApi implements Closeable {
  20772. - private static final String TAG = BaseTaskApi.class.getSimpleName();
  20773. -
  20774. - /**
  20775. - * Represents a pointer to the corresponding C++ task_api object. The nativeHandle pointer is
  20776. - * initialized from subclasses and must be released by calling {@link #deinit} after it is no
  20777. - * longer needed.
  20778. - */
  20779. - private final long nativeHandle;
  20780. -
  20781. - /** Indicates whether the {@link #nativeHandle} pointer has been released yet. */
  20782. - private boolean closed;
  20783. -
  20784. - /**
  20785. - * Constructor to initialize the JNI with a pointer from C++.
  20786. - *
  20787. - * @param nativeHandle a pointer referencing memory allocated in C++.
  20788. - */
  20789. - protected BaseTaskApi(long nativeHandle) {
  20790. - if (nativeHandle == TaskJniUtils.INVALID_POINTER) {
  20791. - throw new IllegalArgumentException("Failed to load C++ pointer from JNI");
  20792. + private static final String TAG = BaseTaskApi.class.getSimpleName();
  20793. +
  20794. + /**
  20795. + * Represents a pointer to the corresponding C++ task_api object. The nativeHandle pointer is
  20796. + * initialized from subclasses and must be released by calling {@link #deinit} after it is no
  20797. + * longer needed.
  20798. + */
  20799. + private final long nativeHandle;
  20800. +
  20801. + /** Indicates whether the {@link #nativeHandle} pointer has been released yet. */
  20802. + private boolean closed;
  20803. +
  20804. + /**
  20805. + * Constructor to initialize the JNI with a pointer from C++.
  20806. + *
  20807. + * @param nativeHandle a pointer referencing memory allocated in C++.
  20808. + */
  20809. + protected BaseTaskApi(long nativeHandle) {
  20810. + if (nativeHandle == TaskJniUtils.INVALID_POINTER) {
  20811. + throw new IllegalArgumentException("Failed to load C++ pointer from JNI");
  20812. + }
  20813. + this.nativeHandle = nativeHandle;
  20814. + }
  20815. +
  20816. + public boolean isClosed() {
  20817. + return closed;
  20818. }
  20819. - this.nativeHandle = nativeHandle;
  20820. - }
  20821. -
  20822. - public boolean isClosed() {
  20823. - return closed;
  20824. - }
  20825. -
  20826. - /** Release the memory allocated from C++ and deregister the library from the static holder. */
  20827. - @Override
  20828. - public synchronized void close() {
  20829. - if (closed) {
  20830. - return;
  20831. +
  20832. + /** Release the memory allocated from C++ and deregister the library from the static holder. */
  20833. + @Override
  20834. + public synchronized void close() {
  20835. + if (closed) {
  20836. + return;
  20837. + }
  20838. + deinit(nativeHandle);
  20839. + closed = true;
  20840. }
  20841. - deinit(nativeHandle);
  20842. - closed = true;
  20843. - }
  20844. - public long getNativeHandle() {
  20845. - return nativeHandle;
  20846. - }
  20847. + public long getNativeHandle() {
  20848. + return nativeHandle;
  20849. + }
  20850. - protected void checkNotClosed() {
  20851. - if (isClosed()) {
  20852. - throw new IllegalStateException("Internal error: The task lib has already been closed.");
  20853. + protected void checkNotClosed() {
  20854. + if (isClosed()) {
  20855. + throw new IllegalStateException(
  20856. + "Internal error: The task lib has already been closed.");
  20857. + }
  20858. }
  20859. - }
  20860. -
  20861. - @Override
  20862. - protected void finalize() throws Throwable {
  20863. - try {
  20864. - if (!closed) {
  20865. - Log.w(TAG, "Closing an already closed native lib");
  20866. - close();
  20867. - }
  20868. - } finally {
  20869. - super.finalize();
  20870. +
  20871. + @Override
  20872. + protected void finalize() throws Throwable {
  20873. + try {
  20874. + if (!closed) {
  20875. + Log.w(TAG, "Closing an already closed native lib");
  20876. + close();
  20877. + }
  20878. + } finally {
  20879. + super.finalize();
  20880. + }
  20881. }
  20882. - }
  20883. -
  20884. - /**
  20885. - * Releases memory pointed by the pointer in the native layer.
  20886. - *
  20887. - * @param nativeHandle pointer to memory allocated
  20888. - */
  20889. - protected abstract void deinit(long nativeHandle);
  20890. +
  20891. + /**
  20892. + * Releases memory pointed by the pointer in the native layer.
  20893. + *
  20894. + * @param nativeHandle pointer to memory allocated
  20895. + */
  20896. + protected abstract void deinit(long nativeHandle);
  20897. }
  20898. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java
  20899. index 80a9e82ff3802..0c2d04283594d 100644
  20900. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java
  20901. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java
  20902. @@ -20,38 +20,36 @@ import com.google.auto.value.AutoValue;
  20903. /** Options to configure how to accelerate the model inference using dedicated delegates. */
  20904. @AutoValue
  20905. public abstract class ComputeSettings {
  20906. + /** TFLite accelerator delegate options. */
  20907. + public enum Delegate {
  20908. + NONE(0),
  20909. + NNAPI(1),
  20910. + GPU(2);
  20911. - /** TFLite accelerator delegate options. */
  20912. - public enum Delegate {
  20913. - NONE(0),
  20914. - NNAPI(1),
  20915. - GPU(2);
  20916. + private final int value;
  20917. - private final int value;
  20918. + Delegate(int value) {
  20919. + this.value = value;
  20920. + }
  20921. - Delegate(int value) {
  20922. - this.value = value;
  20923. + public int getValue() {
  20924. + return value;
  20925. + }
  20926. }
  20927. - public int getValue() {
  20928. - return value;
  20929. - }
  20930. - }
  20931. -
  20932. - /** Builder for {@link ComputeSettings}. */
  20933. - @AutoValue.Builder
  20934. - public abstract static class Builder {
  20935. -
  20936. - public abstract Builder setDelegate(Delegate delegate);
  20937. + /** Builder for {@link ComputeSettings}. */
  20938. + @AutoValue.Builder
  20939. + public abstract static class Builder {
  20940. + public abstract Builder setDelegate(Delegate delegate);
  20941. - public abstract ComputeSettings build();
  20942. - }
  20943. + public abstract ComputeSettings build();
  20944. + }
  20945. - public static Builder builder() {
  20946. - return new AutoValue_ComputeSettings.Builder().setDelegate(DEFAULT_DELEGATE);
  20947. - }
  20948. + public static Builder builder() {
  20949. + return new AutoValue_ComputeSettings.Builder().setDelegate(DEFAULT_DELEGATE);
  20950. + }
  20951. - public abstract Delegate getDelegate();
  20952. + public abstract Delegate getDelegate();
  20953. - private static final Delegate DEFAULT_DELEGATE = Delegate.NONE;
  20954. + private static final Delegate DEFAULT_DELEGATE = Delegate.NONE;
  20955. }
  20956. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java
  20957. index 76109f453b01f..9d5b775456c43 100644
  20958. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java
  20959. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java
  20960. @@ -18,6 +18,7 @@ package org.tensorflow.lite.task.core;
  20961. import android.content.Context;
  20962. import android.content.res.AssetFileDescriptor;
  20963. import android.util.Log;
  20964. +
  20965. import java.io.FileInputStream;
  20966. import java.io.IOException;
  20967. import java.nio.ByteBuffer;
  20968. @@ -26,156 +27,146 @@ import java.nio.channels.FileChannel;
  20969. /** JNI utils for Task API. */
  20970. public class TaskJniUtils {
  20971. - public static final long INVALID_POINTER = 0;
  20972. - private static final String TAG = TaskJniUtils.class.getSimpleName();
  20973. - /** Syntax sugar to get nativeHandle from empty param list. */
  20974. - public interface EmptyHandleProvider {
  20975. - long createHandle();
  20976. - }
  20977. -
  20978. - /** Syntax sugar to get nativeHandle from an array of {@link ByteBuffer}s. */
  20979. - public interface MultipleBuffersHandleProvider {
  20980. - long createHandle(ByteBuffer... buffers);
  20981. - }
  20982. -
  20983. - /** Syntax sugar to get nativeHandle from file descriptor and options. */
  20984. - public interface FdAndOptionsHandleProvider<T> {
  20985. - long createHandle(
  20986. - int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset, T options);
  20987. - }
  20988. -
  20989. - /**
  20990. - * Initializes the JNI and returns C++ handle with file descriptor and options for task API.
  20991. - *
  20992. - * @param context the Android app context
  20993. - * @param provider provider to get C++ handle, usually returned from native call
  20994. - * @param libName name of C++ lib to be loaded
  20995. - * @param filePath path of the file to be loaded
  20996. - * @param options options to set up the task API, used by the provider
  20997. - * @return C++ handle as long
  20998. - * @throws IOException If model file fails to load.
  20999. - */
  21000. - public static <T> long createHandleFromFdAndOptions(
  21001. - Context context,
  21002. - final FdAndOptionsHandleProvider<T> provider,
  21003. - String libName,
  21004. - String filePath,
  21005. - final T options)
  21006. - throws IOException {
  21007. - try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(filePath)) {
  21008. - return createHandleFromLibrary(
  21009. - new EmptyHandleProvider() {
  21010. + public static final long INVALID_POINTER = 0;
  21011. + private static final String TAG = TaskJniUtils.class.getSimpleName();
  21012. + /** Syntax sugar to get nativeHandle from empty param list. */
  21013. + public interface EmptyHandleProvider {
  21014. + long createHandle();
  21015. + }
  21016. +
  21017. + /** Syntax sugar to get nativeHandle from an array of {@link ByteBuffer}s. */
  21018. + public interface MultipleBuffersHandleProvider {
  21019. + long createHandle(ByteBuffer... buffers);
  21020. + }
  21021. +
  21022. + /** Syntax sugar to get nativeHandle from file descriptor and options. */
  21023. + public interface FdAndOptionsHandleProvider<T> {
  21024. + long createHandle(int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset,
  21025. + T options);
  21026. + }
  21027. +
  21028. + /**
  21029. + * Initializes the JNI and returns C++ handle with file descriptor and options for task API.
  21030. + *
  21031. + * @param context the Android app context
  21032. + * @param provider provider to get C++ handle, usually returned from native call
  21033. + * @param libName name of C++ lib to be loaded
  21034. + * @param filePath path of the file to be loaded
  21035. + * @param options options to set up the task API, used by the provider
  21036. + * @return C++ handle as long
  21037. + * @throws IOException If model file fails to load.
  21038. + */
  21039. + public static <T> long createHandleFromFdAndOptions(Context context,
  21040. + final FdAndOptionsHandleProvider<T> provider, String libName, String filePath,
  21041. + final T options) throws IOException {
  21042. + try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(filePath)) {
  21043. + return createHandleFromLibrary(new EmptyHandleProvider() {
  21044. + @Override
  21045. + public long createHandle() {
  21046. + return provider.createHandle(
  21047. + /*fileDescriptor=*/assetFileDescriptor.getParcelFileDescriptor()
  21048. + .getFd(),
  21049. + /*fileDescriptorLength=*/assetFileDescriptor.getLength(),
  21050. + /*fileDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options);
  21051. + }
  21052. + }, libName);
  21053. + }
  21054. + }
  21055. +
  21056. + /**
  21057. + * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes
  21058. + * {@link EmptyHandleProvider#createHandle()}.
  21059. + *
  21060. + * @param provider provider to get C++ handle, usually returned from native call
  21061. + * @return C++ handle as long
  21062. + */
  21063. + public static long createHandleFromLibrary(EmptyHandleProvider provider, String libName) {
  21064. + tryLoadLibrary(libName);
  21065. + try {
  21066. + return provider.createHandle();
  21067. + } catch (RuntimeException e) {
  21068. + String errorMessage = "Error getting native address of native library: " + libName;
  21069. + Log.e(TAG, errorMessage, e);
  21070. + throw new IllegalStateException(errorMessage, e);
  21071. + }
  21072. + }
  21073. +
  21074. + /**
  21075. + * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes
  21076. + * {@link MultipleBuffersHandleProvider#createHandle(ByteBuffer...)}.
  21077. + *
  21078. + * @param context app context
  21079. + * @param provider provider to get C++ pointer, usually returned from native call
  21080. + * @param libName name of C++ lib to load
  21081. + * @param filePaths file paths to load
  21082. + * @return C++ pointer as long
  21083. + * @throws IOException If model file fails to load.
  21084. + */
  21085. + public static long createHandleWithMultipleAssetFilesFromLibrary(Context context,
  21086. + final MultipleBuffersHandleProvider provider, String libName, String... filePaths)
  21087. + throws IOException {
  21088. + final MappedByteBuffer[] buffers = new MappedByteBuffer[filePaths.length];
  21089. + for (int i = 0; i < filePaths.length; i++) {
  21090. + buffers[i] = loadMappedFile(context, filePaths[i]);
  21091. + }
  21092. + return createHandleFromLibrary(new EmptyHandleProvider() {
  21093. @Override
  21094. public long createHandle() {
  21095. - return provider.createHandle(
  21096. - /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
  21097. - /*fileDescriptorLength=*/ assetFileDescriptor.getLength(),
  21098. - /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
  21099. - options);
  21100. + return provider.createHandle(buffers);
  21101. }
  21102. - },
  21103. - libName);
  21104. - }
  21105. - }
  21106. -
  21107. - /**
  21108. - * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes
  21109. - * {@link EmptyHandleProvider#createHandle()}.
  21110. - *
  21111. - * @param provider provider to get C++ handle, usually returned from native call
  21112. - * @return C++ handle as long
  21113. - */
  21114. - public static long createHandleFromLibrary(EmptyHandleProvider provider, String libName) {
  21115. - tryLoadLibrary(libName);
  21116. - try {
  21117. - return provider.createHandle();
  21118. - } catch (RuntimeException e) {
  21119. - String errorMessage = "Error getting native address of native library: " + libName;
  21120. - Log.e(TAG, errorMessage, e);
  21121. - throw new IllegalStateException(errorMessage, e);
  21122. - }
  21123. - }
  21124. -
  21125. - /**
  21126. - * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes
  21127. - * {@link MultipleBuffersHandleProvider#createHandle(ByteBuffer...)}.
  21128. - *
  21129. - * @param context app context
  21130. - * @param provider provider to get C++ pointer, usually returned from native call
  21131. - * @param libName name of C++ lib to load
  21132. - * @param filePaths file paths to load
  21133. - * @return C++ pointer as long
  21134. - * @throws IOException If model file fails to load.
  21135. - */
  21136. - public static long createHandleWithMultipleAssetFilesFromLibrary(
  21137. - Context context,
  21138. - final MultipleBuffersHandleProvider provider,
  21139. - String libName,
  21140. - String... filePaths)
  21141. - throws IOException {
  21142. - final MappedByteBuffer[] buffers = new MappedByteBuffer[filePaths.length];
  21143. - for (int i = 0; i < filePaths.length; i++) {
  21144. - buffers[i] = loadMappedFile(context, filePaths[i]);
  21145. + }, libName);
  21146. }
  21147. - return createHandleFromLibrary(
  21148. - new EmptyHandleProvider() {
  21149. - @Override
  21150. - public long createHandle() {
  21151. - return provider.createHandle(buffers);
  21152. - }
  21153. - },
  21154. - libName);
  21155. - }
  21156. -
  21157. - /**
  21158. - * Loads a file from the asset folder through memory mapping.
  21159. - *
  21160. - * @param context Application context to access assets.
  21161. - * @param filePath Asset path of the file.
  21162. - * @return the loaded memory mapped file.
  21163. - * @throws IOException If model file fails to load.
  21164. - */
  21165. - public static MappedByteBuffer loadMappedFile(Context context, String filePath)
  21166. - throws IOException {
  21167. - try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
  21168. - FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
  21169. - FileChannel fileChannel = inputStream.getChannel();
  21170. - long startOffset = fileDescriptor.getStartOffset();
  21171. - long declaredLength = fileDescriptor.getDeclaredLength();
  21172. - return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  21173. +
  21174. + /**
  21175. + * Loads a file from the asset folder through memory mapping.
  21176. + *
  21177. + * @param context Application context to access assets.
  21178. + * @param filePath Asset path of the file.
  21179. + * @return the loaded memory mapped file.
  21180. + * @throws IOException If model file fails to load.
  21181. + */
  21182. + public static MappedByteBuffer loadMappedFile(Context context, String filePath)
  21183. + throws IOException {
  21184. + try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
  21185. + FileInputStream inputStream =
  21186. + new FileInputStream(fileDescriptor.getFileDescriptor())) {
  21187. + FileChannel fileChannel = inputStream.getChannel();
  21188. + long startOffset = fileDescriptor.getStartOffset();
  21189. + long declaredLength = fileDescriptor.getDeclaredLength();
  21190. + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  21191. + }
  21192. }
  21193. - }
  21194. -
  21195. - /**
  21196. - * Try loading a native library, if it's already loaded return directly.
  21197. - *
  21198. - * @param libName name of the lib
  21199. - */
  21200. - public static void tryLoadLibrary(String libName) {
  21201. - try {
  21202. - System.loadLibrary(libName);
  21203. - } catch (UnsatisfiedLinkError e) {
  21204. - String errorMessage = "Error loading native library: " + libName;
  21205. - Log.e(TAG, errorMessage, e);
  21206. - throw new UnsatisfiedLinkError(errorMessage);
  21207. +
  21208. + /**
  21209. + * Try loading a native library, if it's already loaded return directly.
  21210. + *
  21211. + * @param libName name of the lib
  21212. + */
  21213. + public static void tryLoadLibrary(String libName) {
  21214. + try {
  21215. + System.loadLibrary(libName);
  21216. + } catch (UnsatisfiedLinkError e) {
  21217. + String errorMessage = "Error loading native library: " + libName;
  21218. + Log.e(TAG, errorMessage, e);
  21219. + throw new UnsatisfiedLinkError(errorMessage);
  21220. + }
  21221. }
  21222. - }
  21223. - public static long createProtoBaseOptionsHandle(BaseOptions baseOptions) {
  21224. - return createProtoBaseOptionsHandleWithLegacyNumThreads(baseOptions, /*legacyNumThreads =*/ -1);
  21225. - }
  21226. + public static long createProtoBaseOptionsHandle(BaseOptions baseOptions) {
  21227. + return createProtoBaseOptionsHandleWithLegacyNumThreads(
  21228. + baseOptions, /*legacyNumThreads =*/-1);
  21229. + }
  21230. - public static long createProtoBaseOptionsHandleWithLegacyNumThreads(
  21231. - BaseOptions baseOptions, int legacyNumThreads) {
  21232. - // NumThreads should be configured through BaseOptions. However, if NumThreads is configured
  21233. - // through the legacy API of the Task Java API (then it will not equal to -1, the default
  21234. - // value), use it to overide the one in baseOptions.
  21235. - return createProtoBaseOptions(
  21236. - baseOptions.getComputeSettings().getDelegate().getValue(),
  21237. - legacyNumThreads == -1 ? baseOptions.getNumThreads() : legacyNumThreads);
  21238. - }
  21239. + public static long createProtoBaseOptionsHandleWithLegacyNumThreads(
  21240. + BaseOptions baseOptions, int legacyNumThreads) {
  21241. + // NumThreads should be configured through BaseOptions. However, if NumThreads is configured
  21242. + // through the legacy API of the Task Java API (then it will not equal to -1, the default
  21243. + // value), use it to overide the one in baseOptions.
  21244. + return createProtoBaseOptions(baseOptions.getComputeSettings().getDelegate().getValue(),
  21245. + legacyNumThreads == -1 ? baseOptions.getNumThreads() : legacyNumThreads);
  21246. + }
  21247. - private TaskJniUtils() {}
  21248. + private TaskJniUtils() {}
  21249. - private static native long createProtoBaseOptions(int delegate, int numThreads);
  21250. + private static native long createProtoBaseOptions(int delegate, int numThreads);
  21251. }
  21252. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/annotations/UsedByReflection.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/annotations/UsedByReflection.java
  21253. index bfa1ea750cf1f..fb1dfec82d7b4 100644
  21254. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/annotations/UsedByReflection.java
  21255. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/annotations/UsedByReflection.java
  21256. @@ -27,5 +27,5 @@ import java.lang.annotation.Target;
  21257. */
  21258. @Target({ElementType.METHOD, ElementType.FIELD, ElementType.TYPE, ElementType.CONSTRUCTOR})
  21259. public @interface UsedByReflection {
  21260. - String value();
  21261. + String value();
  21262. }
  21263. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java
  21264. index 287ba444c386b..b1784d02f2362 100644
  21265. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java
  21266. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java
  21267. @@ -16,6 +16,7 @@ limitations under the License.
  21268. package org.tensorflow.lite.task.core.vision;
  21269. import android.graphics.Rect;
  21270. +
  21271. import com.google.auto.value.AutoValue;
  21272. /**
  21273. @@ -45,74 +46,74 @@ import com.google.auto.value.AutoValue;
  21274. */
  21275. @AutoValue
  21276. public abstract class ImageProcessingOptions {
  21277. -
  21278. - /**
  21279. - * Orientation type that follows EXIF specification.
  21280. - *
  21281. - * <p>The name of each enum value defines the position of the 0th row and the 0th column of the
  21282. - * image content. See the <a href="http://jpegclub.org/exif_orientation.html">EXIF orientation
  21283. - * documentation</a> for details.
  21284. - */
  21285. - public enum Orientation {
  21286. - TOP_LEFT(0),
  21287. - TOP_RIGHT(1),
  21288. - BOTTOM_RIGHT(2),
  21289. - BOTTOM_LEFT(3),
  21290. - LEFT_TOP(4),
  21291. - RIGHT_TOP(5),
  21292. - RIGHT_BOTTOM(6),
  21293. - LEFT_BOTTOM(7);
  21294. -
  21295. - private final int value;
  21296. -
  21297. - Orientation(int value) {
  21298. - this.value = value;
  21299. - }
  21300. -
  21301. - public int getValue() {
  21302. - return value;
  21303. - }
  21304. - };
  21305. -
  21306. - private static final Rect defaultRoi = new Rect();
  21307. - private static final Orientation DEFAULT_ORIENTATION = Orientation.TOP_LEFT;
  21308. -
  21309. - public abstract Rect getRoi();
  21310. -
  21311. - public abstract Orientation getOrientation();
  21312. -
  21313. - public static Builder builder() {
  21314. - return new AutoValue_ImageProcessingOptions.Builder()
  21315. - .setRoi(defaultRoi)
  21316. - .setOrientation(DEFAULT_ORIENTATION);
  21317. - }
  21318. -
  21319. - /** Builder for {@link ImageProcessingOptions}. */
  21320. - @AutoValue.Builder
  21321. - public abstract static class Builder {
  21322. -
  21323. /**
  21324. - * Sets the region of interest (ROI) of the image. Defaults to the entire image.
  21325. + * Orientation type that follows EXIF specification.
  21326. *
  21327. - * <p>Cropping according to this region of interest is prepended to the pre-processing
  21328. - * operations.
  21329. + * <p>The name of each enum value defines the position of the 0th row and the 0th column of the
  21330. + * image content. See the <a href="http://jpegclub.org/exif_orientation.html">EXIF orientation
  21331. + * documentation</a> for details.
  21332. */
  21333. - public abstract Builder setRoi(Rect roi);
  21334. + public enum Orientation {
  21335. + TOP_LEFT(0),
  21336. + TOP_RIGHT(1),
  21337. + BOTTOM_RIGHT(2),
  21338. + BOTTOM_LEFT(3),
  21339. + LEFT_TOP(4),
  21340. + RIGHT_TOP(5),
  21341. + RIGHT_BOTTOM(6),
  21342. + LEFT_BOTTOM(7);
  21343. +
  21344. + private final int value;
  21345. +
  21346. + Orientation(int value) {
  21347. + this.value = value;
  21348. + }
  21349. +
  21350. + public int getValue() {
  21351. + return value;
  21352. + }
  21353. + }
  21354. + ;
  21355. - /**
  21356. - * Sets the orientation of the image. Defaults to {@link Orientation#TOP_LEFT}.
  21357. - *
  21358. - * <p>Rotation will be applied accordingly so that inference is performed on an "upright" image.
  21359. - */
  21360. - public abstract Builder setOrientation(Orientation orientation);
  21361. + private static final Rect defaultRoi = new Rect();
  21362. + private static final Orientation DEFAULT_ORIENTATION = Orientation.TOP_LEFT;
  21363. - abstract Rect getRoi();
  21364. + public abstract Rect getRoi();
  21365. - abstract ImageProcessingOptions autoBuild();
  21366. + public abstract Orientation getOrientation();
  21367. +
  21368. + public static Builder builder() {
  21369. + return new AutoValue_ImageProcessingOptions.Builder()
  21370. + .setRoi(defaultRoi)
  21371. + .setOrientation(DEFAULT_ORIENTATION);
  21372. + }
  21373. - public ImageProcessingOptions build() {
  21374. - setRoi(new Rect(getRoi())); // Make a defensive copy, since Rect is mutable.
  21375. - return autoBuild();
  21376. + /** Builder for {@link ImageProcessingOptions}. */
  21377. + @AutoValue.Builder
  21378. + public abstract static class Builder {
  21379. + /**
  21380. + * Sets the region of interest (ROI) of the image. Defaults to the entire image.
  21381. + *
  21382. + * <p>Cropping according to this region of interest is prepended to the pre-processing
  21383. + * operations.
  21384. + */
  21385. + public abstract Builder setRoi(Rect roi);
  21386. +
  21387. + /**
  21388. + * Sets the orientation of the image. Defaults to {@link Orientation#TOP_LEFT}.
  21389. + *
  21390. + * <p>Rotation will be applied accordingly so that inference is performed on an "upright"
  21391. + * image.
  21392. + */
  21393. + public abstract Builder setOrientation(Orientation orientation);
  21394. +
  21395. + abstract Rect getRoi();
  21396. +
  21397. + abstract ImageProcessingOptions autoBuild();
  21398. +
  21399. + public ImageProcessingOptions build() {
  21400. + setRoi(new Rect(getRoi())); // Make a defensive copy, since Rect is mutable.
  21401. + return autoBuild();
  21402. + }
  21403. }
  21404. - }
  21405. }
  21406. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/NearestNeighbor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/NearestNeighbor.java
  21407. index f5cc5af615117..a39247f1239c8 100644
  21408. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/NearestNeighbor.java
  21409. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/NearestNeighbor.java
  21410. @@ -16,37 +16,38 @@ limitations under the License.
  21411. package org.tensorflow.lite.task.processor;
  21412. import com.google.auto.value.AutoValue;
  21413. +
  21414. +import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  21415. +
  21416. import java.nio.ByteBuffer;
  21417. import java.nio.ByteOrder;
  21418. -import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  21419. /** Represents the search result of a Searcher model. */
  21420. @AutoValue
  21421. @UsedByReflection("searcher_jni.cc")
  21422. public abstract class NearestNeighbor {
  21423. -
  21424. - @UsedByReflection("searcher_jni.cc")
  21425. - static NearestNeighbor create(byte[] metadataArray, float distance) {
  21426. - // Convert byte[] metadataArray to ByteBuffer which handles endianess better.
  21427. - //
  21428. - // Ideally, the API should accept a ByteBuffer instead of a byte[]. However, converting byte[]
  21429. - // to ByteBuffer in JNI will lead to unnecessarily complex code which involves 6 more reflection
  21430. - // calls. We can make this method package private, because users in general shouldn't need to
  21431. - // create NearestNeighbor instances, but only consume the objects return from Task Library. This
  21432. - // API will be used mostly for internal purpose.
  21433. - ByteBuffer metadata = ByteBuffer.wrap(metadataArray);
  21434. - metadata.order(ByteOrder.nativeOrder());
  21435. - return new AutoValue_NearestNeighbor(metadata, distance);
  21436. - }
  21437. -
  21438. - /**
  21439. - * Gets the user-defined metadata about the result. This could be a label, a unique ID, a
  21440. - * serialized proto of some sort, etc.
  21441. - *
  21442. - * <p><b>Do not mutate</b> the returned metadata.
  21443. - */
  21444. - public abstract ByteBuffer getMetadata();
  21445. -
  21446. - /** Gets the distance score indicating how confident the result is. Lower is better. */
  21447. - public abstract float getDistance();
  21448. + @UsedByReflection("searcher_jni.cc")
  21449. + static NearestNeighbor create(byte[] metadataArray, float distance) {
  21450. + // Convert byte[] metadataArray to ByteBuffer which handles endianess better.
  21451. + //
  21452. + // Ideally, the API should accept a ByteBuffer instead of a byte[]. However, converting
  21453. + // byte[] to ByteBuffer in JNI will lead to unnecessarily complex code which involves 6 more
  21454. + // reflection calls. We can make this method package private, because users in general
  21455. + // shouldn't need to create NearestNeighbor instances, but only consume the objects return
  21456. + // from Task Library. This API will be used mostly for internal purpose.
  21457. + ByteBuffer metadata = ByteBuffer.wrap(metadataArray);
  21458. + metadata.order(ByteOrder.nativeOrder());
  21459. + return new AutoValue_NearestNeighbor(metadata, distance);
  21460. + }
  21461. +
  21462. + /**
  21463. + * Gets the user-defined metadata about the result. This could be a label, a unique ID, a
  21464. + * serialized proto of some sort, etc.
  21465. + *
  21466. + * <p><b>Do not mutate</b> the returned metadata.
  21467. + */
  21468. + public abstract ByteBuffer getMetadata();
  21469. +
  21470. + /** Gets the distance score indicating how confident the result is. Lower is better. */
  21471. + public abstract float getDistance();
  21472. }
  21473. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/SearcherOptions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/SearcherOptions.java
  21474. index fa601edf92b30..86f5fdde0187c 100644
  21475. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/SearcherOptions.java
  21476. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/SearcherOptions.java
  21477. @@ -16,66 +16,68 @@ limitations under the License.
  21478. package org.tensorflow.lite.task.processor;
  21479. import androidx.annotation.Nullable;
  21480. +
  21481. import com.google.auto.value.AutoValue;
  21482. +
  21483. import java.io.File;
  21484. /** Options to configure Searcher API. */
  21485. @AutoValue
  21486. public abstract class SearcherOptions {
  21487. - private static final boolean DEFAULT_L2_NORMALIZE = false;
  21488. - private static final boolean DEFAULT_QUANTIZE = false;
  21489. - private static final int DEFAULT_MAX_RESULTS = 5;
  21490. -
  21491. - public abstract boolean getL2Normalize();
  21492. -
  21493. - public abstract boolean getQuantize();
  21494. -
  21495. - @Nullable
  21496. - public abstract File getIndexFile();
  21497. -
  21498. - public abstract int getMaxResults();
  21499. -
  21500. - public static Builder builder() {
  21501. - return new AutoValue_SearcherOptions.Builder()
  21502. - .setL2Normalize(DEFAULT_L2_NORMALIZE)
  21503. - .setQuantize(DEFAULT_QUANTIZE)
  21504. - .setIndexFile(null)
  21505. - .setMaxResults(DEFAULT_MAX_RESULTS);
  21506. - }
  21507. -
  21508. - /** Builder for {@link SearcherOptions}. */
  21509. - @AutoValue.Builder
  21510. - public abstract static class Builder {
  21511. - /**
  21512. - * Sets whether to normalize the embedding feature vector with L2 norm. Defaults to false.
  21513. - *
  21514. - * <p>Use this option only if the model does not already contain a native L2_NORMALIZATION
  21515. - * TFLite Op. In most cases, this is already the case and L2 norm is thus achieved through
  21516. - * TFLite inference.
  21517. - */
  21518. - public abstract Builder setL2Normalize(boolean l2Normalize);
  21519. -
  21520. - /**
  21521. - * Sets whether the embedding should be quantized to bytes via scalar quantization. Defaults to
  21522. - * false.
  21523. - *
  21524. - * <p>Embeddings are implicitly assumed to be unit-norm and therefore any dimension is
  21525. - * guaranteed to have a value in {@code [-1.0, 1.0]}. Use the l2_normalize option if this is not
  21526. - * the case.
  21527. - */
  21528. - public abstract Builder setQuantize(boolean quantize);
  21529. -
  21530. - /**
  21531. - * Sets the index file to search into.
  21532. - *
  21533. - * <p>Required if the model does not come with an index file inside. Otherwise, it can be ignore
  21534. - * by setting to {@code null}.
  21535. - */
  21536. - public abstract Builder setIndexFile(@Nullable File indexFile);
  21537. -
  21538. - /** Sets the maximum number of nearest neighbor results to return. Defaults to {@code 5} */
  21539. - public abstract Builder setMaxResults(int maxResults);
  21540. -
  21541. - public abstract SearcherOptions build();
  21542. - }
  21543. + private static final boolean DEFAULT_L2_NORMALIZE = false;
  21544. + private static final boolean DEFAULT_QUANTIZE = false;
  21545. + private static final int DEFAULT_MAX_RESULTS = 5;
  21546. +
  21547. + public abstract boolean getL2Normalize();
  21548. +
  21549. + public abstract boolean getQuantize();
  21550. +
  21551. + @Nullable
  21552. + public abstract File getIndexFile();
  21553. +
  21554. + public abstract int getMaxResults();
  21555. +
  21556. + public static Builder builder() {
  21557. + return new AutoValue_SearcherOptions.Builder()
  21558. + .setL2Normalize(DEFAULT_L2_NORMALIZE)
  21559. + .setQuantize(DEFAULT_QUANTIZE)
  21560. + .setIndexFile(null)
  21561. + .setMaxResults(DEFAULT_MAX_RESULTS);
  21562. + }
  21563. +
  21564. + /** Builder for {@link SearcherOptions}. */
  21565. + @AutoValue.Builder
  21566. + public abstract static class Builder {
  21567. + /**
  21568. + * Sets whether to normalize the embedding feature vector with L2 norm. Defaults to false.
  21569. + *
  21570. + * <p>Use this option only if the model does not already contain a native L2_NORMALIZATION
  21571. + * TFLite Op. In most cases, this is already the case and L2 norm is thus achieved through
  21572. + * TFLite inference.
  21573. + */
  21574. + public abstract Builder setL2Normalize(boolean l2Normalize);
  21575. +
  21576. + /**
  21577. + * Sets whether the embedding should be quantized to bytes via scalar quantization. Defaults
  21578. + * to false.
  21579. + *
  21580. + * <p>Embeddings are implicitly assumed to be unit-norm and therefore any dimension is
  21581. + * guaranteed to have a value in {@code [-1.0, 1.0]}. Use the l2_normalize option if this is
  21582. + * not the case.
  21583. + */
  21584. + public abstract Builder setQuantize(boolean quantize);
  21585. +
  21586. + /**
  21587. + * Sets the index file to search into.
  21588. + *
  21589. + * <p>Required if the model does not come with an index file inside. Otherwise, it can be
  21590. + * ignore by setting to {@code null}.
  21591. + */
  21592. + public abstract Builder setIndexFile(@Nullable File indexFile);
  21593. +
  21594. + /** Sets the maximum number of nearest neighbor results to return. Defaults to {@code 5} */
  21595. + public abstract Builder setMaxResults(int maxResults);
  21596. +
  21597. + public abstract SearcherOptions build();
  21598. + }
  21599. }
  21600. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
  21601. index 55743055ff408..070b945e72b90 100644
  21602. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
  21603. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
  21604. @@ -17,12 +17,9 @@ package org.tensorflow.lite.task.text.nlclassifier;
  21605. import android.content.Context;
  21606. import android.os.ParcelFileDescriptor;
  21607. +
  21608. import com.google.auto.value.AutoValue;
  21609. -import java.io.File;
  21610. -import java.io.IOException;
  21611. -import java.nio.ByteBuffer;
  21612. -import java.nio.MappedByteBuffer;
  21613. -import java.util.List;
  21614. +
  21615. import org.tensorflow.lite.support.label.Category;
  21616. import org.tensorflow.lite.task.core.BaseOptions;
  21617. import org.tensorflow.lite.task.core.BaseTaskApi;
  21618. @@ -30,6 +27,12 @@ import org.tensorflow.lite.task.core.TaskJniUtils;
  21619. import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
  21620. import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  21621. +import java.io.File;
  21622. +import java.io.IOException;
  21623. +import java.nio.ByteBuffer;
  21624. +import java.nio.MappedByteBuffer;
  21625. +import java.util.List;
  21626. +
  21627. /**
  21628. * Classifier API for NLClassification tasks with Bert models, categorizes string into different
  21629. * classes. The API expects a Bert based TFLite model with metadata populated.
  21630. @@ -45,209 +48,199 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  21631. * </ul>
  21632. */
  21633. public class BertNLClassifier extends BaseTaskApi {
  21634. + private static final String BERT_NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni";
  21635. +
  21636. + /** Options to configure BertNLClassifier. */
  21637. + @AutoValue
  21638. + @UsedByReflection("bert_nl_classifier_jni.cc")
  21639. + public abstract static class BertNLClassifierOptions {
  21640. + static final int DEFAULT_MAX_SEQ_LEN = 128;
  21641. +
  21642. + abstract int getMaxSeqLen();
  21643. +
  21644. + abstract BaseOptions getBaseOptions();
  21645. +
  21646. + public static Builder builder() {
  21647. + return new AutoValue_BertNLClassifier_BertNLClassifierOptions.Builder()
  21648. + .setMaxSeqLen(DEFAULT_MAX_SEQ_LEN)
  21649. + .setBaseOptions(BaseOptions.builder().build());
  21650. + }
  21651. +
  21652. + /** Builder for {@link BertNLClassifierOptions}. */
  21653. + @AutoValue.Builder
  21654. + public abstract static class Builder {
  21655. + /** Sets the general options to configure Task APIs, such as accelerators. */
  21656. + public abstract Builder setBaseOptions(BaseOptions baseOptions);
  21657. +
  21658. + /**
  21659. + * Set the maximum sequence length.
  21660. + *
  21661. + * @deprecated maximum sequence length is now read from the model (i.e. input tensor
  21662. + * size)
  21663. + * automatically
  21664. + */
  21665. + @Deprecated
  21666. + public abstract Builder setMaxSeqLen(int value);
  21667. +
  21668. + public abstract BertNLClassifierOptions build();
  21669. + }
  21670. + }
  21671. - private static final String BERT_NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni";
  21672. -
  21673. - /** Options to configure BertNLClassifier. */
  21674. - @AutoValue
  21675. - @UsedByReflection("bert_nl_classifier_jni.cc")
  21676. - public abstract static class BertNLClassifierOptions {
  21677. - static final int DEFAULT_MAX_SEQ_LEN = 128;
  21678. -
  21679. - abstract int getMaxSeqLen();
  21680. + /**
  21681. + * Creates {@link BertNLClassifier} from a model file with metadata and default {@link
  21682. + * BertNLClassifierOptions}.
  21683. + *
  21684. + * @param context Android context
  21685. + * @param modelPath Path to the classification model
  21686. + * @return a {@link BertNLClassifier} instance
  21687. + * @throws IOException If model file fails to load
  21688. + * @throws IllegalArgumentException if an argument is invalid
  21689. + * @throws IllegalStateException if there is an internal error
  21690. + * @throws RuntimeException if there is an otherwise unspecified error
  21691. + */
  21692. + public static BertNLClassifier createFromFile(final Context context, final String modelPath)
  21693. + throws IOException {
  21694. + return createFromBuffer(TaskJniUtils.loadMappedFile(context, modelPath));
  21695. + }
  21696. - abstract BaseOptions getBaseOptions();
  21697. + /**
  21698. + * Creates {@link BertNLClassifier} from a {@link File} object with metadata and default {@link
  21699. + * BertNLClassifierOptions}.
  21700. + *
  21701. + * @param modelFile The classification model {@link File} instance
  21702. + * @return a {@link BertNLClassifier} instance
  21703. + * @throws IOException If model file fails to load
  21704. + * @throws IllegalArgumentException if an argument is invalid
  21705. + * @throws IllegalStateException if there is an internal error
  21706. + * @throws RuntimeException if there is an otherwise unspecified error
  21707. + */
  21708. + public static BertNLClassifier createFromFile(File modelFile) throws IOException {
  21709. + return createFromFileAndOptions(modelFile, BertNLClassifierOptions.builder().build());
  21710. + }
  21711. - public static Builder builder() {
  21712. - return new AutoValue_BertNLClassifier_BertNLClassifierOptions.Builder()
  21713. - .setMaxSeqLen(DEFAULT_MAX_SEQ_LEN)
  21714. - .setBaseOptions(BaseOptions.builder().build());
  21715. + /**
  21716. + * Creates {@link BertNLClassifier} from a model file with metadata and {@link
  21717. + * BertNLClassifierOptions}.
  21718. + *
  21719. + * @param context Android context.
  21720. + * @param modelPath Path to the classification model
  21721. + * @param options to configure the classifier
  21722. + * @return a {@link BertNLClassifier} instance
  21723. + * @throws IOException If model file fails to load
  21724. + * @throws IllegalArgumentException if an argument is invalid
  21725. + * @throws IllegalStateException if there is an internal error
  21726. + * @throws RuntimeException if there is an otherwise unspecified error
  21727. + */
  21728. + public static BertNLClassifier createFromFileAndOptions(final Context context,
  21729. + final String modelPath, BertNLClassifierOptions options) throws IOException {
  21730. + return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options);
  21731. }
  21732. - /** Builder for {@link BertNLClassifierOptions}. */
  21733. - @AutoValue.Builder
  21734. - public abstract static class Builder {
  21735. + /**
  21736. + * Creates {@link BertNLClassifier} from a {@link File} object with metadata and {@link
  21737. + * BertNLClassifierOptions}.
  21738. + *
  21739. + * @param modelFile The classification model {@link File} instance
  21740. + * @param options to configure the classifier
  21741. + * @return a {@link BertNLClassifier} instance
  21742. + * @throws IOException If model file fails to load
  21743. + * @throws IllegalArgumentException if an argument is invalid
  21744. + * @throws IllegalStateException if there is an internal error
  21745. + * @throws RuntimeException if there is an otherwise unspecified error
  21746. + */
  21747. + public static BertNLClassifier createFromFileAndOptions(
  21748. + File modelFile, final BertNLClassifierOptions options) throws IOException {
  21749. + try (ParcelFileDescriptor descriptor =
  21750. + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  21751. + return new BertNLClassifier(
  21752. + TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
  21753. + @Override
  21754. + public long createHandle() {
  21755. + return initJniWithFileDescriptor(descriptor.getFd(), options,
  21756. + TaskJniUtils.createProtoBaseOptionsHandle(
  21757. + options.getBaseOptions()));
  21758. + }
  21759. + }, BERT_NL_CLASSIFIER_NATIVE_LIBNAME));
  21760. + }
  21761. + }
  21762. - /** Sets the general options to configure Task APIs, such as accelerators. */
  21763. - public abstract Builder setBaseOptions(BaseOptions baseOptions);
  21764. + /**
  21765. + * Creates {@link BertNLClassifier} with a model buffer and default {@link
  21766. + * BertNLClassifierOptions}.
  21767. + *
  21768. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model
  21769. + * @return a {@link BertNLClassifier} instance
  21770. + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  21771. + * {@link MappedByteBuffer}
  21772. + * @throws IllegalStateException if there is an internal error
  21773. + * @throws RuntimeException if there is an otherwise unspecified error
  21774. + */
  21775. + public static BertNLClassifier createFromBuffer(final ByteBuffer modelBuffer) {
  21776. + return createFromBufferAndOptions(modelBuffer, BertNLClassifierOptions.builder().build());
  21777. + }
  21778. - /**
  21779. - * Set the maximum sequence length.
  21780. - *
  21781. - * @deprecated maximum sequence length is now read from the model (i.e. input tensor size)
  21782. - * automatically
  21783. - */
  21784. - @Deprecated
  21785. - public abstract Builder setMaxSeqLen(int value);
  21786. + /**
  21787. + * Creates {@link BertNLClassifier} with a model buffer and {@link BertNLClassifierOptions}.
  21788. + *
  21789. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model
  21790. + * @param options to configure the classifier
  21791. + * @return a {@link BertNLClassifier} instance
  21792. + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  21793. + * {@link MappedByteBuffer}
  21794. + * @throws IllegalStateException if there is an internal error
  21795. + * @throws RuntimeException if there is an otherwise unspecified error
  21796. + */
  21797. + public static BertNLClassifier createFromBufferAndOptions(
  21798. + final ByteBuffer modelBuffer, final BertNLClassifierOptions options) {
  21799. + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
  21800. + throw new IllegalArgumentException(
  21801. + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
  21802. + }
  21803. + return new BertNLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
  21804. + @Override
  21805. + public long createHandle() {
  21806. + return initJniWithByteBuffer(modelBuffer, options,
  21807. + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
  21808. + }
  21809. + }, BERT_NL_CLASSIFIER_NATIVE_LIBNAME));
  21810. + }
  21811. - public abstract BertNLClassifierOptions build();
  21812. + /**
  21813. + * Performs classification on a string input, returns classified {@link Category}s.
  21814. + *
  21815. + * @param text input text to the model.
  21816. + * @return A list of Category results.
  21817. + */
  21818. + public List<Category> classify(String text) {
  21819. + return classifyNative(getNativeHandle(), text);
  21820. }
  21821. - }
  21822. -
  21823. - /**
  21824. - * Creates {@link BertNLClassifier} from a model file with metadata and default {@link
  21825. - * BertNLClassifierOptions}.
  21826. - *
  21827. - * @param context Android context
  21828. - * @param modelPath Path to the classification model
  21829. - * @return a {@link BertNLClassifier} instance
  21830. - * @throws IOException If model file fails to load
  21831. - * @throws IllegalArgumentException if an argument is invalid
  21832. - * @throws IllegalStateException if there is an internal error
  21833. - * @throws RuntimeException if there is an otherwise unspecified error
  21834. - */
  21835. - public static BertNLClassifier createFromFile(final Context context, final String modelPath)
  21836. - throws IOException {
  21837. - return createFromBuffer(TaskJniUtils.loadMappedFile(context, modelPath));
  21838. - }
  21839. -
  21840. - /**
  21841. - * Creates {@link BertNLClassifier} from a {@link File} object with metadata and default {@link
  21842. - * BertNLClassifierOptions}.
  21843. - *
  21844. - * @param modelFile The classification model {@link File} instance
  21845. - * @return a {@link BertNLClassifier} instance
  21846. - * @throws IOException If model file fails to load
  21847. - * @throws IllegalArgumentException if an argument is invalid
  21848. - * @throws IllegalStateException if there is an internal error
  21849. - * @throws RuntimeException if there is an otherwise unspecified error
  21850. - */
  21851. - public static BertNLClassifier createFromFile(File modelFile) throws IOException {
  21852. - return createFromFileAndOptions(modelFile, BertNLClassifierOptions.builder().build());
  21853. - }
  21854. -
  21855. - /**
  21856. - * Creates {@link BertNLClassifier} from a model file with metadata and {@link
  21857. - * BertNLClassifierOptions}.
  21858. - *
  21859. - * @param context Android context.
  21860. - * @param modelPath Path to the classification model
  21861. - * @param options to configure the classifier
  21862. - * @return a {@link BertNLClassifier} instance
  21863. - * @throws IOException If model file fails to load
  21864. - * @throws IllegalArgumentException if an argument is invalid
  21865. - * @throws IllegalStateException if there is an internal error
  21866. - * @throws RuntimeException if there is an otherwise unspecified error
  21867. - */
  21868. - public static BertNLClassifier createFromFileAndOptions(
  21869. - final Context context, final String modelPath, BertNLClassifierOptions options)
  21870. - throws IOException {
  21871. - return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options);
  21872. - }
  21873. -
  21874. - /**
  21875. - * Creates {@link BertNLClassifier} from a {@link File} object with metadata and {@link
  21876. - * BertNLClassifierOptions}.
  21877. - *
  21878. - * @param modelFile The classification model {@link File} instance
  21879. - * @param options to configure the classifier
  21880. - * @return a {@link BertNLClassifier} instance
  21881. - * @throws IOException If model file fails to load
  21882. - * @throws IllegalArgumentException if an argument is invalid
  21883. - * @throws IllegalStateException if there is an internal error
  21884. - * @throws RuntimeException if there is an otherwise unspecified error
  21885. - */
  21886. - public static BertNLClassifier createFromFileAndOptions(
  21887. - File modelFile, final BertNLClassifierOptions options) throws IOException {
  21888. - try (ParcelFileDescriptor descriptor =
  21889. - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  21890. - return new BertNLClassifier(
  21891. - TaskJniUtils.createHandleFromLibrary(
  21892. - new EmptyHandleProvider() {
  21893. - @Override
  21894. - public long createHandle() {
  21895. - return initJniWithFileDescriptor(
  21896. - descriptor.getFd(),
  21897. - options,
  21898. - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
  21899. - }
  21900. - },
  21901. - BERT_NL_CLASSIFIER_NATIVE_LIBNAME));
  21902. +
  21903. + /**
  21904. + * Constructor to initialize the JNI with a pointer from C++.
  21905. + *
  21906. + * @param nativeHandle a pointer referencing memory allocated in C++.
  21907. + */
  21908. + private BertNLClassifier(long nativeHandle) {
  21909. + super(nativeHandle);
  21910. }
  21911. - }
  21912. -
  21913. - /**
  21914. - * Creates {@link BertNLClassifier} with a model buffer and default {@link
  21915. - * BertNLClassifierOptions}.
  21916. - *
  21917. - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model
  21918. - * @return a {@link BertNLClassifier} instance
  21919. - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  21920. - * {@link MappedByteBuffer}
  21921. - * @throws IllegalStateException if there is an internal error
  21922. - * @throws RuntimeException if there is an otherwise unspecified error
  21923. - */
  21924. - public static BertNLClassifier createFromBuffer(final ByteBuffer modelBuffer) {
  21925. - return createFromBufferAndOptions(modelBuffer, BertNLClassifierOptions.builder().build());
  21926. - }
  21927. -
  21928. - /**
  21929. - * Creates {@link BertNLClassifier} with a model buffer and {@link BertNLClassifierOptions}.
  21930. - *
  21931. - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model
  21932. - * @param options to configure the classifier
  21933. - * @return a {@link BertNLClassifier} instance
  21934. - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  21935. - * {@link MappedByteBuffer}
  21936. - * @throws IllegalStateException if there is an internal error
  21937. - * @throws RuntimeException if there is an otherwise unspecified error
  21938. - */
  21939. - public static BertNLClassifier createFromBufferAndOptions(
  21940. - final ByteBuffer modelBuffer, final BertNLClassifierOptions options) {
  21941. - if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
  21942. - throw new IllegalArgumentException(
  21943. - "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
  21944. +
  21945. + private static native long initJniWithByteBuffer(
  21946. + ByteBuffer modelBuffer, BertNLClassifierOptions options, long baseOptionsHandle);
  21947. +
  21948. + private static native long initJniWithFileDescriptor(
  21949. + int fd, BertNLClassifierOptions options, long baseOptionsHandle);
  21950. +
  21951. + private static native List<Category> classifyNative(long nativeHandle, String text);
  21952. +
  21953. + @Override
  21954. + protected void deinit(long nativeHandle) {
  21955. + deinitJni(nativeHandle);
  21956. }
  21957. - return new BertNLClassifier(
  21958. - TaskJniUtils.createHandleFromLibrary(
  21959. - new EmptyHandleProvider() {
  21960. - @Override
  21961. - public long createHandle() {
  21962. - return initJniWithByteBuffer(
  21963. - modelBuffer,
  21964. - options,
  21965. - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
  21966. - }
  21967. - },
  21968. - BERT_NL_CLASSIFIER_NATIVE_LIBNAME));
  21969. - }
  21970. -
  21971. - /**
  21972. - * Performs classification on a string input, returns classified {@link Category}s.
  21973. - *
  21974. - * @param text input text to the model.
  21975. - * @return A list of Category results.
  21976. - */
  21977. - public List<Category> classify(String text) {
  21978. - return classifyNative(getNativeHandle(), text);
  21979. - }
  21980. -
  21981. - /**
  21982. - * Constructor to initialize the JNI with a pointer from C++.
  21983. - *
  21984. - * @param nativeHandle a pointer referencing memory allocated in C++.
  21985. - */
  21986. - private BertNLClassifier(long nativeHandle) {
  21987. - super(nativeHandle);
  21988. - }
  21989. -
  21990. - private static native long initJniWithByteBuffer(
  21991. - ByteBuffer modelBuffer, BertNLClassifierOptions options, long baseOptionsHandle);
  21992. -
  21993. - private static native long initJniWithFileDescriptor(
  21994. - int fd, BertNLClassifierOptions options, long baseOptionsHandle);
  21995. -
  21996. - private static native List<Category> classifyNative(long nativeHandle, String text);
  21997. -
  21998. - @Override
  21999. - protected void deinit(long nativeHandle) {
  22000. - deinitJni(nativeHandle);
  22001. - }
  22002. -
  22003. - /**
  22004. - * Native implementation to release memory pointed by the pointer.
  22005. - *
  22006. - * @param nativeHandle pointer to memory allocated
  22007. - */
  22008. - private native void deinitJni(long nativeHandle);
  22009. +
  22010. + /**
  22011. + * Native implementation to release memory pointed by the pointer.
  22012. + *
  22013. + * @param nativeHandle pointer to memory allocated
  22014. + */
  22015. + private native void deinitJni(long nativeHandle);
  22016. }
  22017. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java
  22018. index 19dcffca5e697..5c3eb2c9e3768 100644
  22019. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java
  22020. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java
  22021. @@ -17,13 +17,11 @@ package org.tensorflow.lite.task.text.nlclassifier;
  22022. import android.content.Context;
  22023. import android.os.ParcelFileDescriptor;
  22024. +
  22025. import androidx.annotation.Nullable;
  22026. +
  22027. import com.google.auto.value.AutoValue;
  22028. -import java.io.File;
  22029. -import java.io.IOException;
  22030. -import java.nio.ByteBuffer;
  22031. -import java.nio.MappedByteBuffer;
  22032. -import java.util.List;
  22033. +
  22034. import org.tensorflow.lite.support.label.Category;
  22035. import org.tensorflow.lite.task.core.BaseOptions;
  22036. import org.tensorflow.lite.task.core.BaseTaskApi;
  22037. @@ -31,6 +29,12 @@ import org.tensorflow.lite.task.core.TaskJniUtils;
  22038. import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
  22039. import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  22040. +import java.io.File;
  22041. +import java.io.IOException;
  22042. +import java.nio.ByteBuffer;
  22043. +import java.nio.MappedByteBuffer;
  22044. +import java.util.List;
  22045. +
  22046. /**
  22047. * Classifier API for natural language classification tasks, categorizes string into different
  22048. * classes.
  22049. @@ -67,294 +71,296 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  22050. * configurable for different TFLite models.
  22051. */
  22052. public class NLClassifier extends BaseTaskApi {
  22053. -
  22054. - /** Options to identify input and output tensors of the model. */
  22055. - @AutoValue
  22056. - @UsedByReflection("nl_classifier_jni.cc")
  22057. - public abstract static class NLClassifierOptions {
  22058. - private static final int DEFAULT_INPUT_TENSOR_INDEX = 0;
  22059. - private static final int DEFAULT_OUTPUT_SCORE_TENSOR_INDEX = 0;
  22060. - // By default there is no output label tensor. The label file can be attached
  22061. - // to the output score tensor metadata.
  22062. - private static final int DEFAULT_OUTPUT_LABEL_TENSOR_INDEX = -1;
  22063. - private static final String DEFAULT_INPUT_TENSOR_NAME = "INPUT";
  22064. - private static final String DEFAULT_OUTPUT_SCORE_TENSOR_NAME = "OUTPUT_SCORE";
  22065. - private static final String DEFAULT_OUTPUT_LABEL_TENSOR_NAME = "OUTPUT_LABEL";
  22066. -
  22067. - @UsedByReflection("nl_classifier_jni.cc")
  22068. - abstract int getInputTensorIndex();
  22069. -
  22070. - @UsedByReflection("nl_classifier_jni.cc")
  22071. - abstract int getOutputScoreTensorIndex();
  22072. -
  22073. + /** Options to identify input and output tensors of the model. */
  22074. + @AutoValue
  22075. @UsedByReflection("nl_classifier_jni.cc")
  22076. - abstract int getOutputLabelTensorIndex();
  22077. -
  22078. - @UsedByReflection("nl_classifier_jni.cc")
  22079. - abstract String getInputTensorName();
  22080. + public abstract static class NLClassifierOptions {
  22081. + private static final int DEFAULT_INPUT_TENSOR_INDEX = 0;
  22082. + private static final int DEFAULT_OUTPUT_SCORE_TENSOR_INDEX = 0;
  22083. + // By default there is no output label tensor. The label file can be attached
  22084. + // to the output score tensor metadata.
  22085. + private static final int DEFAULT_OUTPUT_LABEL_TENSOR_INDEX = -1;
  22086. + private static final String DEFAULT_INPUT_TENSOR_NAME = "INPUT";
  22087. + private static final String DEFAULT_OUTPUT_SCORE_TENSOR_NAME = "OUTPUT_SCORE";
  22088. + private static final String DEFAULT_OUTPUT_LABEL_TENSOR_NAME = "OUTPUT_LABEL";
  22089. +
  22090. + @UsedByReflection("nl_classifier_jni.cc")
  22091. + abstract int getInputTensorIndex();
  22092. +
  22093. + @UsedByReflection("nl_classifier_jni.cc")
  22094. + abstract int getOutputScoreTensorIndex();
  22095. +
  22096. + @UsedByReflection("nl_classifier_jni.cc")
  22097. + abstract int getOutputLabelTensorIndex();
  22098. +
  22099. + @UsedByReflection("nl_classifier_jni.cc")
  22100. + abstract String getInputTensorName();
  22101. +
  22102. + @UsedByReflection("nl_classifier_jni.cc")
  22103. + abstract String getOutputScoreTensorName();
  22104. +
  22105. + @UsedByReflection("nl_classifier_jni.cc")
  22106. + abstract String getOutputLabelTensorName();
  22107. +
  22108. + @Nullable
  22109. + abstract BaseOptions getBaseOptions();
  22110. +
  22111. + public static Builder builder() {
  22112. + return new AutoValue_NLClassifier_NLClassifierOptions.Builder()
  22113. + .setInputTensorIndex(DEFAULT_INPUT_TENSOR_INDEX)
  22114. + .setOutputScoreTensorIndex(DEFAULT_OUTPUT_SCORE_TENSOR_INDEX)
  22115. + .setOutputLabelTensorIndex(DEFAULT_OUTPUT_LABEL_TENSOR_INDEX)
  22116. + .setInputTensorName(DEFAULT_INPUT_TENSOR_NAME)
  22117. + .setOutputScoreTensorName(DEFAULT_OUTPUT_SCORE_TENSOR_NAME)
  22118. + .setOutputLabelTensorName(DEFAULT_OUTPUT_LABEL_TENSOR_NAME);
  22119. + }
  22120. +
  22121. + /** Builder for {@link NLClassifierOptions}. */
  22122. + @AutoValue.Builder
  22123. + public abstract static class Builder {
  22124. + /** Sets the general options to configure Task APIs, such as accelerators. */
  22125. + public abstract Builder setBaseOptions(@Nullable BaseOptions baseOptions);
  22126. +
  22127. + /**
  22128. + * Configure the input/output tensors for NLClassifier:
  22129. + *
  22130. + * <p>- No special configuration is needed if the model has only one input tensor and
  22131. + * one output tensor.
  22132. + *
  22133. + * <p>- When the model has multiple input or output tensors, use the following
  22134. + * configurations to specifiy the desired tensors: <br>
  22135. + * -- tensor names: {@code inputTensorName}, {@code outputScoreTensorName}, {@code
  22136. + * outputLabelTensorName}<br>
  22137. + * -- tensor indices: {@code inputTensorIndex}, {@code outputScoreTensorIndex}, {@code
  22138. + * outputLabelTensorIndex} <br>
  22139. + * Tensor names has higher priorities than tensor indices in locating the tensors. It
  22140. + * means the tensors will be first located according to tensor names. If not found, then
  22141. + * the tensors will be located according to tensor indices.
  22142. + *
  22143. + * <p>- Failing to match the input text tensor or output score tensor with neither
  22144. + * tensor names nor tensor indices will trigger a runtime error. However, failing to
  22145. + * locate the output label tensor will not trigger an error because the label tensor is
  22146. + * optional.
  22147. + */
  22148. +
  22149. + /**
  22150. + * Set the name of the input text tensor, if the model has multiple inputs. Only the
  22151. + * input tensor specified will be used for inference; other input tensors will be
  22152. + * ignored. Dafualt to {@code "INPUT"}.
  22153. + *
  22154. + * <p>See the section, Configure the input/output tensors for NLClassifier, for more
  22155. + * details.
  22156. + */
  22157. + public abstract Builder setInputTensorName(String inputTensorName);
  22158. +
  22159. + /**
  22160. + * Set the name of the output score tensor, if the model has multiple outputs. Dafualt
  22161. + * to
  22162. + * {@code "OUTPUT_SCORE"}.
  22163. + *
  22164. + * <p>See the section, Configure the input/output tensors for NLClassifier, for more
  22165. + * details.
  22166. + */
  22167. + public abstract Builder setOutputScoreTensorName(String outputScoreTensorName);
  22168. +
  22169. + /**
  22170. + * Set the name of the output label tensor, if the model has multiple outputs. Dafualt
  22171. + * to
  22172. + * {@code "OUTPUT_LABEL"}.
  22173. + *
  22174. + * <p>See the section, Configure the input/output tensors for NLClassifier, for more
  22175. + * details.
  22176. + *
  22177. + * <p>By default, label file should be packed with the output score tensor through Model
  22178. + * Metadata. See the <a
  22179. + * href="https://www.tensorflow.org/lite/convert/metadata_writer_tutorial#natural_language_classifiers">MetadataWriter
  22180. + * for NLClassifier</a>. NLClassifier reads and parses labels from the label file
  22181. + * automatically. However, some models may output a specific label tensor instead. In
  22182. + * this case, NLClassifier reads labels from the output label tensor.
  22183. + */
  22184. + public abstract Builder setOutputLabelTensorName(String outputLabelTensorName);
  22185. +
  22186. + /**
  22187. + * Set the index of the input text tensor among all input tensors, if the model has
  22188. + * multiple inputs. Only the input tensor specified will be used for inference; other
  22189. + * input tensors will be ignored. Dafualt to 0.
  22190. + *
  22191. + * <p>See the section, Configure the input/output tensors for NLClassifier, for more
  22192. + * details.
  22193. + */
  22194. + public abstract Builder setInputTensorIndex(int inputTensorIndex);
  22195. +
  22196. + /**
  22197. + * Set the index of the output score tensor among all output tensors, if the model has
  22198. + * multiple outputs. Dafualt to 0.
  22199. + *
  22200. + * <p>See the section, Configure the input/output tensors for NLClassifier, for more
  22201. + * details.
  22202. + */
  22203. + public abstract Builder setOutputScoreTensorIndex(int outputScoreTensorIndex);
  22204. +
  22205. + /**
  22206. + * Set the index of the optional output label tensor among all output tensors, if the
  22207. + * model has multiple outputs.
  22208. + *
  22209. + * <p>See the document above {@code outputLabelTensorName} for more information about
  22210. + * what the output label tensor is.
  22211. + *
  22212. + * <p>See the section, Configure the input/output tensors for NLClassifier, for more
  22213. + * details.
  22214. + *
  22215. + * <p>{@code outputLabelTensorIndex} dafualts to -1, meaning to disable the output label
  22216. + * tensor.
  22217. + */
  22218. + public abstract Builder setOutputLabelTensorIndex(int outputLabelTensorIndex);
  22219. +
  22220. + public abstract NLClassifierOptions build();
  22221. + }
  22222. + }
  22223. - @UsedByReflection("nl_classifier_jni.cc")
  22224. - abstract String getOutputScoreTensorName();
  22225. + private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni";
  22226. +
  22227. + /**
  22228. + * Creates {@link NLClassifier} from default {@link NLClassifierOptions}.
  22229. + *
  22230. + * @param context Android context
  22231. + * @param modelPath path to the classification model relative to asset dir
  22232. + * @return an {@link NLClassifier} instance
  22233. + * @throws IOException if model file fails to load
  22234. + * @throws IllegalArgumentException if an argument is invalid
  22235. + * @throws IllegalStateException if there is an internal error
  22236. + * @throws RuntimeException if there is an otherwise unspecified error
  22237. + */
  22238. + public static NLClassifier createFromFile(Context context, String modelPath)
  22239. + throws IOException {
  22240. + return createFromFileAndOptions(context, modelPath, NLClassifierOptions.builder().build());
  22241. + }
  22242. - @UsedByReflection("nl_classifier_jni.cc")
  22243. - abstract String getOutputLabelTensorName();
  22244. -
  22245. - @Nullable
  22246. - abstract BaseOptions getBaseOptions();
  22247. -
  22248. - public static Builder builder() {
  22249. - return new AutoValue_NLClassifier_NLClassifierOptions.Builder()
  22250. - .setInputTensorIndex(DEFAULT_INPUT_TENSOR_INDEX)
  22251. - .setOutputScoreTensorIndex(DEFAULT_OUTPUT_SCORE_TENSOR_INDEX)
  22252. - .setOutputLabelTensorIndex(DEFAULT_OUTPUT_LABEL_TENSOR_INDEX)
  22253. - .setInputTensorName(DEFAULT_INPUT_TENSOR_NAME)
  22254. - .setOutputScoreTensorName(DEFAULT_OUTPUT_SCORE_TENSOR_NAME)
  22255. - .setOutputLabelTensorName(DEFAULT_OUTPUT_LABEL_TENSOR_NAME);
  22256. + /**
  22257. + * Creates {@link NLClassifier} from default {@link NLClassifierOptions}.
  22258. + *
  22259. + * @param modelFile the classification model {@link File} instance
  22260. + * @return an {@link NLClassifier} instance
  22261. + * @throws IOException if model file fails to load
  22262. + * @throws IllegalArgumentException if an argument is invalid
  22263. + * @throws IllegalStateException if there is an internal error
  22264. + * @throws RuntimeException if there is an otherwise unspecified error
  22265. + */
  22266. + public static NLClassifier createFromFile(File modelFile) throws IOException {
  22267. + return createFromFileAndOptions(modelFile, NLClassifierOptions.builder().build());
  22268. }
  22269. - /** Builder for {@link NLClassifierOptions}. */
  22270. - @AutoValue.Builder
  22271. - public abstract static class Builder {
  22272. - /** Sets the general options to configure Task APIs, such as accelerators. */
  22273. - public abstract Builder setBaseOptions(@Nullable BaseOptions baseOptions);
  22274. -
  22275. - /**
  22276. - * Configure the input/output tensors for NLClassifier:
  22277. - *
  22278. - * <p>- No special configuration is needed if the model has only one input tensor and one
  22279. - * output tensor.
  22280. - *
  22281. - * <p>- When the model has multiple input or output tensors, use the following configurations
  22282. - * to specifiy the desired tensors: <br>
  22283. - * -- tensor names: {@code inputTensorName}, {@code outputScoreTensorName}, {@code
  22284. - * outputLabelTensorName}<br>
  22285. - * -- tensor indices: {@code inputTensorIndex}, {@code outputScoreTensorIndex}, {@code
  22286. - * outputLabelTensorIndex} <br>
  22287. - * Tensor names has higher priorities than tensor indices in locating the tensors. It means
  22288. - * the tensors will be first located according to tensor names. If not found, then the tensors
  22289. - * will be located according to tensor indices.
  22290. - *
  22291. - * <p>- Failing to match the input text tensor or output score tensor with neither tensor
  22292. - * names nor tensor indices will trigger a runtime error. However, failing to locate the
  22293. - * output label tensor will not trigger an error because the label tensor is optional.
  22294. - */
  22295. -
  22296. - /**
  22297. - * Set the name of the input text tensor, if the model has multiple inputs. Only the input
  22298. - * tensor specified will be used for inference; other input tensors will be ignored. Dafualt
  22299. - * to {@code "INPUT"}.
  22300. - *
  22301. - * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
  22302. - */
  22303. - public abstract Builder setInputTensorName(String inputTensorName);
  22304. -
  22305. - /**
  22306. - * Set the name of the output score tensor, if the model has multiple outputs. Dafualt to
  22307. - * {@code "OUTPUT_SCORE"}.
  22308. - *
  22309. - * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
  22310. - */
  22311. - public abstract Builder setOutputScoreTensorName(String outputScoreTensorName);
  22312. -
  22313. - /**
  22314. - * Set the name of the output label tensor, if the model has multiple outputs. Dafualt to
  22315. - * {@code "OUTPUT_LABEL"}.
  22316. - *
  22317. - * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
  22318. - *
  22319. - * <p>By default, label file should be packed with the output score tensor through Model
  22320. - * Metadata. See the <a
  22321. - * href="https://www.tensorflow.org/lite/convert/metadata_writer_tutorial#natural_language_classifiers">MetadataWriter
  22322. - * for NLClassifier</a>. NLClassifier reads and parses labels from the label file
  22323. - * automatically. However, some models may output a specific label tensor instead. In this
  22324. - * case, NLClassifier reads labels from the output label tensor.
  22325. - */
  22326. - public abstract Builder setOutputLabelTensorName(String outputLabelTensorName);
  22327. -
  22328. - /**
  22329. - * Set the index of the input text tensor among all input tensors, if the model has multiple
  22330. - * inputs. Only the input tensor specified will be used for inference; other input tensors
  22331. - * will be ignored. Dafualt to 0.
  22332. - *
  22333. - * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
  22334. - */
  22335. - public abstract Builder setInputTensorIndex(int inputTensorIndex);
  22336. -
  22337. - /**
  22338. - * Set the index of the output score tensor among all output tensors, if the model has
  22339. - * multiple outputs. Dafualt to 0.
  22340. - *
  22341. - * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
  22342. - */
  22343. - public abstract Builder setOutputScoreTensorIndex(int outputScoreTensorIndex);
  22344. -
  22345. - /**
  22346. - * Set the index of the optional output label tensor among all output tensors, if the model
  22347. - * has multiple outputs.
  22348. - *
  22349. - * <p>See the document above {@code outputLabelTensorName} for more information about what the
  22350. - * output label tensor is.
  22351. - *
  22352. - * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
  22353. - *
  22354. - * <p>{@code outputLabelTensorIndex} dafualts to -1, meaning to disable the output label
  22355. - * tensor.
  22356. - */
  22357. - public abstract Builder setOutputLabelTensorIndex(int outputLabelTensorIndex);
  22358. -
  22359. - public abstract NLClassifierOptions build();
  22360. + /**
  22361. + * Creates {@link NLClassifier} from {@link NLClassifierOptions}.
  22362. + *
  22363. + * @param context Android context
  22364. + * @param modelPath path to the classification model relative to asset dir
  22365. + * @param options configurations for the model.
  22366. + * @return an {@link NLClassifier} instance
  22367. + * @throws IOException if model file fails to load
  22368. + * @throws IllegalArgumentException if an argument is invalid
  22369. + * @throws IllegalStateException if there is an internal error
  22370. + * @throws RuntimeException if there is an otherwise unspecified error
  22371. + */
  22372. + public static NLClassifier createFromFileAndOptions(
  22373. + Context context, String modelPath, NLClassifierOptions options) throws IOException {
  22374. + return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options);
  22375. }
  22376. - }
  22377. -
  22378. - private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni";
  22379. -
  22380. - /**
  22381. - * Creates {@link NLClassifier} from default {@link NLClassifierOptions}.
  22382. - *
  22383. - * @param context Android context
  22384. - * @param modelPath path to the classification model relative to asset dir
  22385. - * @return an {@link NLClassifier} instance
  22386. - * @throws IOException if model file fails to load
  22387. - * @throws IllegalArgumentException if an argument is invalid
  22388. - * @throws IllegalStateException if there is an internal error
  22389. - * @throws RuntimeException if there is an otherwise unspecified error
  22390. - */
  22391. - public static NLClassifier createFromFile(Context context, String modelPath) throws IOException {
  22392. - return createFromFileAndOptions(context, modelPath, NLClassifierOptions.builder().build());
  22393. - }
  22394. -
  22395. - /**
  22396. - * Creates {@link NLClassifier} from default {@link NLClassifierOptions}.
  22397. - *
  22398. - * @param modelFile the classification model {@link File} instance
  22399. - * @return an {@link NLClassifier} instance
  22400. - * @throws IOException if model file fails to load
  22401. - * @throws IllegalArgumentException if an argument is invalid
  22402. - * @throws IllegalStateException if there is an internal error
  22403. - * @throws RuntimeException if there is an otherwise unspecified error
  22404. - */
  22405. - public static NLClassifier createFromFile(File modelFile) throws IOException {
  22406. - return createFromFileAndOptions(modelFile, NLClassifierOptions.builder().build());
  22407. - }
  22408. -
  22409. - /**
  22410. - * Creates {@link NLClassifier} from {@link NLClassifierOptions}.
  22411. - *
  22412. - * @param context Android context
  22413. - * @param modelPath path to the classification model relative to asset dir
  22414. - * @param options configurations for the model.
  22415. - * @return an {@link NLClassifier} instance
  22416. - * @throws IOException if model file fails to load
  22417. - * @throws IllegalArgumentException if an argument is invalid
  22418. - * @throws IllegalStateException if there is an internal error
  22419. - * @throws RuntimeException if there is an otherwise unspecified error
  22420. - */
  22421. - public static NLClassifier createFromFileAndOptions(
  22422. - Context context, String modelPath, NLClassifierOptions options) throws IOException {
  22423. - return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options);
  22424. - }
  22425. -
  22426. - /**
  22427. - * Creates {@link NLClassifier} from {@link NLClassifierOptions}.
  22428. - *
  22429. - * @param modelFile the classification model {@link File} instance
  22430. - * @param options configurations for the model
  22431. - * @return an {@link NLClassifier} instance
  22432. - * @throws IOException if model file fails to load
  22433. - * @throws IllegalArgumentException if an argument is invalid
  22434. - * @throws IllegalStateException if there is an internal error
  22435. - * @throws RuntimeException if there is an otherwise unspecified error
  22436. - */
  22437. - public static NLClassifier createFromFileAndOptions(
  22438. - File modelFile, final NLClassifierOptions options) throws IOException {
  22439. - try (ParcelFileDescriptor descriptor =
  22440. - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  22441. - return new NLClassifier(
  22442. - TaskJniUtils.createHandleFromLibrary(
  22443. - new EmptyHandleProvider() {
  22444. +
  22445. + /**
  22446. + * Creates {@link NLClassifier} from {@link NLClassifierOptions}.
  22447. + *
  22448. + * @param modelFile the classification model {@link File} instance
  22449. + * @param options configurations for the model
  22450. + * @return an {@link NLClassifier} instance
  22451. + * @throws IOException if model file fails to load
  22452. + * @throws IllegalArgumentException if an argument is invalid
  22453. + * @throws IllegalStateException if there is an internal error
  22454. + * @throws RuntimeException if there is an otherwise unspecified error
  22455. + */
  22456. + public static NLClassifier createFromFileAndOptions(
  22457. + File modelFile, final NLClassifierOptions options) throws IOException {
  22458. + try (ParcelFileDescriptor descriptor =
  22459. + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  22460. + return new NLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
  22461. @Override
  22462. public long createHandle() {
  22463. - long baseOptionsHandle =
  22464. - options.getBaseOptions() == null
  22465. - ? 0 // pass an invalid native handle
  22466. - : TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions());
  22467. - return initJniWithFileDescriptor(options, descriptor.getFd(), baseOptionsHandle);
  22468. + long baseOptionsHandle = options.getBaseOptions() == null
  22469. + ? 0 // pass an invalid native handle
  22470. + : TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions());
  22471. + return initJniWithFileDescriptor(
  22472. + options, descriptor.getFd(), baseOptionsHandle);
  22473. }
  22474. - },
  22475. - NL_CLASSIFIER_NATIVE_LIBNAME));
  22476. - }
  22477. - }
  22478. -
  22479. - /**
  22480. - * Creates {@link NLClassifier} with a model {@link ByteBuffer} and {@link NLClassifierOptions}.
  22481. - *
  22482. - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
  22483. - * classification model
  22484. - * @param options configurations for the model
  22485. - * @return {@link NLClassifier} instance
  22486. - * @throws IllegalStateException if there is an internal error
  22487. - * @throws RuntimeException if there is an otherwise unspecified error
  22488. - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  22489. - * {@link MappedByteBuffer}
  22490. - */
  22491. - public static NLClassifier createFromBufferAndOptions(
  22492. - final ByteBuffer modelBuffer, final NLClassifierOptions options) {
  22493. - if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
  22494. - throw new IllegalArgumentException(
  22495. - "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
  22496. + }, NL_CLASSIFIER_NATIVE_LIBNAME));
  22497. + }
  22498. }
  22499. - return new NLClassifier(
  22500. - TaskJniUtils.createHandleFromLibrary(
  22501. - new EmptyHandleProvider() {
  22502. - @Override
  22503. - public long createHandle() {
  22504. - long baseOptionsHandle =
  22505. - options.getBaseOptions() == null
  22506. + /**
  22507. + * Creates {@link NLClassifier} with a model {@link ByteBuffer} and {@link NLClassifierOptions}.
  22508. + *
  22509. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
  22510. + * classification model
  22511. + * @param options configurations for the model
  22512. + * @return {@link NLClassifier} instance
  22513. + * @throws IllegalStateException if there is an internal error
  22514. + * @throws RuntimeException if there is an otherwise unspecified error
  22515. + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  22516. + * {@link MappedByteBuffer}
  22517. + */
  22518. + public static NLClassifier createFromBufferAndOptions(
  22519. + final ByteBuffer modelBuffer, final NLClassifierOptions options) {
  22520. + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
  22521. + throw new IllegalArgumentException(
  22522. + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
  22523. + }
  22524. +
  22525. + return new NLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
  22526. + @Override
  22527. + public long createHandle() {
  22528. + long baseOptionsHandle = options.getBaseOptions() == null
  22529. ? 0 // pass an invalid native handle
  22530. : TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions());
  22531. return initJniWithByteBuffer(options, modelBuffer, baseOptionsHandle);
  22532. - }
  22533. - },
  22534. - NL_CLASSIFIER_NATIVE_LIBNAME));
  22535. - }
  22536. -
  22537. - /**
  22538. - * Performs classification on a string input, returns classified {@link Category}s.
  22539. - *
  22540. - * @param text input text to the model
  22541. - * @return a list of Category results
  22542. - */
  22543. - public List<Category> classify(String text) {
  22544. - return classifyNative(getNativeHandle(), text);
  22545. - }
  22546. -
  22547. - /**
  22548. - * Constructor to initialize the JNI with a pointer from C++.
  22549. - *
  22550. - * @param nativeHandle a pointer referencing memory allocated in C++.
  22551. - */
  22552. - protected NLClassifier(long nativeHandle) {
  22553. - super(nativeHandle);
  22554. - }
  22555. -
  22556. - @Override
  22557. - protected void deinit(long nativeHandle) {
  22558. - deinitJni(nativeHandle);
  22559. - }
  22560. -
  22561. - private static native long initJniWithByteBuffer(
  22562. - NLClassifierOptions options, ByteBuffer modelBuffer, long baseOptionsHandle);
  22563. -
  22564. - private static native long initJniWithFileDescriptor(
  22565. - NLClassifierOptions options, int fd, long baseOptionsHandle);
  22566. -
  22567. - private static native List<Category> classifyNative(long nativeHandle, String text);
  22568. -
  22569. - /**
  22570. - * Native implementation to release memory pointed by the pointer.
  22571. - *
  22572. - * @param nativeHandle pointer to memory allocated
  22573. - */
  22574. - private native void deinitJni(long nativeHandle);
  22575. + }
  22576. + }, NL_CLASSIFIER_NATIVE_LIBNAME));
  22577. + }
  22578. +
  22579. + /**
  22580. + * Performs classification on a string input, returns classified {@link Category}s.
  22581. + *
  22582. + * @param text input text to the model
  22583. + * @return a list of Category results
  22584. + */
  22585. + public List<Category> classify(String text) {
  22586. + return classifyNative(getNativeHandle(), text);
  22587. + }
  22588. +
  22589. + /**
  22590. + * Constructor to initialize the JNI with a pointer from C++.
  22591. + *
  22592. + * @param nativeHandle a pointer referencing memory allocated in C++.
  22593. + */
  22594. + protected NLClassifier(long nativeHandle) {
  22595. + super(nativeHandle);
  22596. + }
  22597. +
  22598. + @Override
  22599. + protected void deinit(long nativeHandle) {
  22600. + deinitJni(nativeHandle);
  22601. + }
  22602. +
  22603. + private static native long initJniWithByteBuffer(
  22604. + NLClassifierOptions options, ByteBuffer modelBuffer, long baseOptionsHandle);
  22605. +
  22606. + private static native long initJniWithFileDescriptor(
  22607. + NLClassifierOptions options, int fd, long baseOptionsHandle);
  22608. +
  22609. + private static native List<Category> classifyNative(long nativeHandle, String text);
  22610. +
  22611. + /**
  22612. + * Native implementation to release memory pointed by the pointer.
  22613. + *
  22614. + * @param nativeHandle pointer to memory allocated
  22615. + */
  22616. + private native void deinitJni(long nativeHandle);
  22617. }
  22618. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
  22619. index aafa2c88c55e8..39648d9bb4042 100644
  22620. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
  22621. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
  22622. @@ -17,11 +17,9 @@ package org.tensorflow.lite.task.text.qa;
  22623. import android.content.Context;
  22624. import android.os.ParcelFileDescriptor;
  22625. +
  22626. import com.google.auto.value.AutoValue;
  22627. -import java.io.File;
  22628. -import java.io.IOException;
  22629. -import java.nio.ByteBuffer;
  22630. -import java.util.List;
  22631. +
  22632. import org.tensorflow.lite.task.core.BaseOptions;
  22633. import org.tensorflow.lite.task.core.BaseTaskApi;
  22634. import org.tensorflow.lite.task.core.TaskJniUtils;
  22635. @@ -29,6 +27,11 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
  22636. import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider;
  22637. import org.tensorflow.lite.task.core.TaskJniUtils.MultipleBuffersHandleProvider;
  22638. +import java.io.File;
  22639. +import java.io.IOException;
  22640. +import java.nio.ByteBuffer;
  22641. +import java.util.List;
  22642. +
  22643. /**
  22644. * Returns the most possible answers on a given question for QA models (BERT, Albert, etc.).
  22645. *
  22646. @@ -45,225 +48,204 @@ import org.tensorflow.lite.task.core.TaskJniUtils.MultipleBuffersHandleProvider;
  22647. * </ul>
  22648. */
  22649. public class BertQuestionAnswerer extends BaseTaskApi implements QuestionAnswerer {
  22650. - private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "task_text_jni";
  22651. - private static final int OPTIONAL_FD_LENGTH = -1;
  22652. - private static final int OPTIONAL_FD_OFFSET = -1;
  22653. -
  22654. - /**
  22655. - * Creates a {@link BertQuestionAnswerer} instance from the default {@link
  22656. - * BertQuestionAnswererOptions}.
  22657. - *
  22658. - * @param context android context
  22659. - * @param modelPath file path to the model with metadata. Note: The model should not be compressed
  22660. - * @return a {@link BertQuestionAnswerer} instance
  22661. - * @throws IOException if model file fails to load
  22662. - * @throws IllegalArgumentException if an argument is invalid
  22663. - * @throws IllegalStateException if there is an internal error
  22664. - * @throws RuntimeException if there is an otherwise unspecified error
  22665. - */
  22666. - public static BertQuestionAnswerer createFromFile(Context context, String modelPath)
  22667. - throws IOException {
  22668. - return createFromFileAndOptions(
  22669. - context, modelPath, BertQuestionAnswererOptions.builder().build());
  22670. - }
  22671. + private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "task_text_jni";
  22672. + private static final int OPTIONAL_FD_LENGTH = -1;
  22673. + private static final int OPTIONAL_FD_OFFSET = -1;
  22674. +
  22675. + /**
  22676. + * Creates a {@link BertQuestionAnswerer} instance from the default {@link
  22677. + * BertQuestionAnswererOptions}.
  22678. + *
  22679. + * @param context android context
  22680. + * @param modelPath file path to the model with metadata. Note: The model should not be
  22681. + * compressed
  22682. + * @return a {@link BertQuestionAnswerer} instance
  22683. + * @throws IOException if model file fails to load
  22684. + * @throws IllegalArgumentException if an argument is invalid
  22685. + * @throws IllegalStateException if there is an internal error
  22686. + * @throws RuntimeException if there is an otherwise unspecified error
  22687. + */
  22688. + public static BertQuestionAnswerer createFromFile(Context context, String modelPath)
  22689. + throws IOException {
  22690. + return createFromFileAndOptions(
  22691. + context, modelPath, BertQuestionAnswererOptions.builder().build());
  22692. + }
  22693. - /**
  22694. - * Creates a {@link BertQuestionAnswerer} instance from the default {@link
  22695. - * BertQuestionAnswererOptions}.
  22696. - *
  22697. - * @param modelFile a {@link File} object of the model
  22698. - * @return a {@link BertQuestionAnswerer} instance
  22699. - * @throws IOException if model file fails to load
  22700. - * @throws IllegalArgumentException if an argument is invalid
  22701. - * @throws IllegalStateException if there is an internal error
  22702. - * @throws RuntimeException if there is an otherwise unspecified error
  22703. - */
  22704. - public static BertQuestionAnswerer createFromFile(File modelFile) throws IOException {
  22705. - return createFromFileAndOptions(modelFile, BertQuestionAnswererOptions.builder().build());
  22706. - }
  22707. + /**
  22708. + * Creates a {@link BertQuestionAnswerer} instance from the default {@link
  22709. + * BertQuestionAnswererOptions}.
  22710. + *
  22711. + * @param modelFile a {@link File} object of the model
  22712. + * @return a {@link BertQuestionAnswerer} instance
  22713. + * @throws IOException if model file fails to load
  22714. + * @throws IllegalArgumentException if an argument is invalid
  22715. + * @throws IllegalStateException if there is an internal error
  22716. + * @throws RuntimeException if there is an otherwise unspecified error
  22717. + */
  22718. + public static BertQuestionAnswerer createFromFile(File modelFile) throws IOException {
  22719. + return createFromFileAndOptions(modelFile, BertQuestionAnswererOptions.builder().build());
  22720. + }
  22721. - /**
  22722. - * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}.
  22723. - *
  22724. - * @param context android context
  22725. - * @param modelPath file path to the model with metadata. Note: The model should not be compressed
  22726. - * @return a {@link BertQuestionAnswerer} instance
  22727. - * @throws IOException if model file fails to load
  22728. - * @throws IllegalArgumentException if an argument is invalid
  22729. - * @throws IllegalStateException if there is an internal error
  22730. - * @throws RuntimeException if there is an otherwise unspecified error
  22731. - */
  22732. - public static BertQuestionAnswerer createFromFileAndOptions(
  22733. - Context context, String modelPath, BertQuestionAnswererOptions options) throws IOException {
  22734. - return new BertQuestionAnswerer(
  22735. - TaskJniUtils.createHandleFromFdAndOptions(
  22736. - context,
  22737. - new FdAndOptionsHandleProvider<BertQuestionAnswererOptions>() {
  22738. - @Override
  22739. - public long createHandle(
  22740. - int fileDescriptor,
  22741. - long fileDescriptorLength,
  22742. - long fileDescriptorOffset,
  22743. - BertQuestionAnswererOptions options) {
  22744. - return initJniWithFileDescriptor(
  22745. - fileDescriptor,
  22746. - fileDescriptorLength,
  22747. - fileDescriptorOffset,
  22748. - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
  22749. - }
  22750. - },
  22751. - BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
  22752. - modelPath,
  22753. - options));
  22754. - }
  22755. + /**
  22756. + * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}.
  22757. + *
  22758. + * @param context android context
  22759. + * @param modelPath file path to the model with metadata. Note: The model should not be
  22760. + * compressed
  22761. + * @return a {@link BertQuestionAnswerer} instance
  22762. + * @throws IOException if model file fails to load
  22763. + * @throws IllegalArgumentException if an argument is invalid
  22764. + * @throws IllegalStateException if there is an internal error
  22765. + * @throws RuntimeException if there is an otherwise unspecified error
  22766. + */
  22767. + public static BertQuestionAnswerer createFromFileAndOptions(Context context, String modelPath,
  22768. + BertQuestionAnswererOptions options) throws IOException {
  22769. + return new BertQuestionAnswerer(TaskJniUtils.createHandleFromFdAndOptions(
  22770. + context, new FdAndOptionsHandleProvider<BertQuestionAnswererOptions>() {
  22771. + @Override
  22772. + public long createHandle(int fileDescriptor, long fileDescriptorLength,
  22773. + long fileDescriptorOffset, BertQuestionAnswererOptions options) {
  22774. + return initJniWithFileDescriptor(fileDescriptor, fileDescriptorLength,
  22775. + fileDescriptorOffset,
  22776. + TaskJniUtils.createProtoBaseOptionsHandle(
  22777. + options.getBaseOptions()));
  22778. + }
  22779. + }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, modelPath, options));
  22780. + }
  22781. - /**
  22782. - * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}.
  22783. - *
  22784. - * @param modelFile a {@link File} object of the model
  22785. - * @return a {@link BertQuestionAnswerer} instance
  22786. - * @throws IOException if model file fails to load
  22787. - * @throws IllegalArgumentException if an argument is invalid
  22788. - * @throws IllegalStateException if there is an internal error
  22789. - * @throws RuntimeException if there is an otherwise unspecified error
  22790. - */
  22791. - public static BertQuestionAnswerer createFromFileAndOptions(
  22792. - File modelFile, final BertQuestionAnswererOptions options) throws IOException {
  22793. - try (ParcelFileDescriptor descriptor =
  22794. - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  22795. - return new BertQuestionAnswerer(
  22796. - TaskJniUtils.createHandleFromLibrary(
  22797. - new EmptyHandleProvider() {
  22798. - @Override
  22799. - public long createHandle() {
  22800. - return initJniWithFileDescriptor(
  22801. - /*fileDescriptor=*/ descriptor.getFd(),
  22802. - /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
  22803. - /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
  22804. - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
  22805. - }
  22806. - },
  22807. - BERT_QUESTION_ANSWERER_NATIVE_LIBNAME));
  22808. + /**
  22809. + * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}.
  22810. + *
  22811. + * @param modelFile a {@link File} object of the model
  22812. + * @return a {@link BertQuestionAnswerer} instance
  22813. + * @throws IOException if model file fails to load
  22814. + * @throws IllegalArgumentException if an argument is invalid
  22815. + * @throws IllegalStateException if there is an internal error
  22816. + * @throws RuntimeException if there is an otherwise unspecified error
  22817. + */
  22818. + public static BertQuestionAnswerer createFromFileAndOptions(
  22819. + File modelFile, final BertQuestionAnswererOptions options) throws IOException {
  22820. + try (ParcelFileDescriptor descriptor =
  22821. + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  22822. + return new BertQuestionAnswerer(
  22823. + TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
  22824. + @Override
  22825. + public long createHandle() {
  22826. + return initJniWithFileDescriptor(
  22827. + /*fileDescriptor=*/descriptor.getFd(),
  22828. + /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH,
  22829. + /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET,
  22830. + TaskJniUtils.createProtoBaseOptionsHandle(
  22831. + options.getBaseOptions()));
  22832. + }
  22833. + }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME));
  22834. + }
  22835. }
  22836. - }
  22837. - /**
  22838. - * Creates a {@link BertQuestionAnswerer} instance with a Bert model and a vocabulary file.
  22839. - *
  22840. - * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
  22841. - *
  22842. - * @param context android context
  22843. - * @param modelPath file path to the Bert model. Note: The model should not be compressed
  22844. - * @param vocabPath file path to the vocabulary file. Note: The file should not be compressed
  22845. - * @return a {@link BertQuestionAnswerer} instance
  22846. - * @throws IOException If model file fails to load
  22847. - * @throws IllegalArgumentException if an argument is invalid
  22848. - * @throws IllegalStateException if there is an internal error
  22849. - * @throws RuntimeException if there is an otherwise unspecified error
  22850. - */
  22851. - public static BertQuestionAnswerer createBertQuestionAnswererFromFile(
  22852. - Context context, String modelPath, String vocabPath) throws IOException {
  22853. - return new BertQuestionAnswerer(
  22854. - TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
  22855. - context,
  22856. - new MultipleBuffersHandleProvider() {
  22857. - @Override
  22858. - public long createHandle(ByteBuffer... buffers) {
  22859. - return initJniWithBertByteBuffers(buffers);
  22860. - }
  22861. - },
  22862. - BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
  22863. - modelPath,
  22864. - vocabPath));
  22865. - }
  22866. + /**
  22867. + * Creates a {@link BertQuestionAnswerer} instance with a Bert model and a vocabulary file.
  22868. + *
  22869. + * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
  22870. + *
  22871. + * @param context android context
  22872. + * @param modelPath file path to the Bert model. Note: The model should not be compressed
  22873. + * @param vocabPath file path to the vocabulary file. Note: The file should not be compressed
  22874. + * @return a {@link BertQuestionAnswerer} instance
  22875. + * @throws IOException If model file fails to load
  22876. + * @throws IllegalArgumentException if an argument is invalid
  22877. + * @throws IllegalStateException if there is an internal error
  22878. + * @throws RuntimeException if there is an otherwise unspecified error
  22879. + */
  22880. + public static BertQuestionAnswerer createBertQuestionAnswererFromFile(
  22881. + Context context, String modelPath, String vocabPath) throws IOException {
  22882. + return new BertQuestionAnswerer(TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
  22883. + context, new MultipleBuffersHandleProvider() {
  22884. + @Override
  22885. + public long createHandle(ByteBuffer... buffers) {
  22886. + return initJniWithBertByteBuffers(buffers);
  22887. + }
  22888. + }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, modelPath, vocabPath));
  22889. + }
  22890. - /**
  22891. - * Creates a {@link BertQuestionAnswerer} instance with an Albert model and a sentence piece model
  22892. - * file.
  22893. - *
  22894. - * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
  22895. - *
  22896. - * @param context android context
  22897. - * @param modelPath file path to the Albert model. Note: The model should not be compressed
  22898. - * @param sentencePieceModelPath file path to the sentence piece model file. Note: The model
  22899. - * should not be compressed
  22900. - * @return a {@link BertQuestionAnswerer} instance
  22901. - * @throws IOException If model file fails to load
  22902. - * @throws IllegalArgumentException if an argument is invalid
  22903. - * @throws IllegalStateException if there is an internal error
  22904. - * @throws RuntimeException if there is an otherwise unspecified error
  22905. - */
  22906. - public static BertQuestionAnswerer createAlbertQuestionAnswererFromFile(
  22907. - Context context, String modelPath, String sentencePieceModelPath) throws IOException {
  22908. - return new BertQuestionAnswerer(
  22909. - TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
  22910. - context,
  22911. - new MultipleBuffersHandleProvider() {
  22912. - @Override
  22913. - public long createHandle(ByteBuffer... buffers) {
  22914. - return initJniWithAlbertByteBuffers(buffers);
  22915. - }
  22916. - },
  22917. - BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
  22918. - modelPath,
  22919. - sentencePieceModelPath));
  22920. - }
  22921. + /**
  22922. + * Creates a {@link BertQuestionAnswerer} instance with an Albert model and a sentence piece
  22923. + * model file.
  22924. + *
  22925. + * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
  22926. + *
  22927. + * @param context android context
  22928. + * @param modelPath file path to the Albert model. Note: The model should not be compressed
  22929. + * @param sentencePieceModelPath file path to the sentence piece model file. Note: The model
  22930. + * should not be compressed
  22931. + * @return a {@link BertQuestionAnswerer} instance
  22932. + * @throws IOException If model file fails to load
  22933. + * @throws IllegalArgumentException if an argument is invalid
  22934. + * @throws IllegalStateException if there is an internal error
  22935. + * @throws RuntimeException if there is an otherwise unspecified error
  22936. + */
  22937. + public static BertQuestionAnswerer createAlbertQuestionAnswererFromFile(
  22938. + Context context, String modelPath, String sentencePieceModelPath) throws IOException {
  22939. + return new BertQuestionAnswerer(TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
  22940. + context, new MultipleBuffersHandleProvider() {
  22941. + @Override
  22942. + public long createHandle(ByteBuffer... buffers) {
  22943. + return initJniWithAlbertByteBuffers(buffers);
  22944. + }
  22945. + }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, modelPath, sentencePieceModelPath));
  22946. + }
  22947. - /** Options for setting up a {@link BertQuestionAnswerer}. */
  22948. - @AutoValue
  22949. - public abstract static class BertQuestionAnswererOptions {
  22950. - abstract BaseOptions getBaseOptions();
  22951. + /** Options for setting up a {@link BertQuestionAnswerer}. */
  22952. + @AutoValue
  22953. + public abstract static class BertQuestionAnswererOptions {
  22954. + abstract BaseOptions getBaseOptions();
  22955. - public static Builder builder() {
  22956. - return new AutoValue_BertQuestionAnswerer_BertQuestionAnswererOptions.Builder()
  22957. - .setBaseOptions(BaseOptions.builder().build());
  22958. - }
  22959. + public static Builder builder() {
  22960. + return new AutoValue_BertQuestionAnswerer_BertQuestionAnswererOptions.Builder()
  22961. + .setBaseOptions(BaseOptions.builder().build());
  22962. + }
  22963. - /** Builder for {@link BertQuestionAnswererOptions}. */
  22964. - @AutoValue.Builder
  22965. - public abstract static class Builder {
  22966. - /** Sets the general options to configure Task APIs, such as accelerators. */
  22967. - public abstract Builder setBaseOptions(BaseOptions baseOptions);
  22968. + /** Builder for {@link BertQuestionAnswererOptions}. */
  22969. + @AutoValue.Builder
  22970. + public abstract static class Builder {
  22971. + /** Sets the general options to configure Task APIs, such as accelerators. */
  22972. + public abstract Builder setBaseOptions(BaseOptions baseOptions);
  22973. - public abstract BertQuestionAnswererOptions build();
  22974. + public abstract BertQuestionAnswererOptions build();
  22975. + }
  22976. }
  22977. - }
  22978. - @Override
  22979. - public List<QaAnswer> answer(String context, String question) {
  22980. - checkNotClosed();
  22981. - return answerNative(getNativeHandle(), context, question);
  22982. - }
  22983. + @Override
  22984. + public List<QaAnswer> answer(String context, String question) {
  22985. + checkNotClosed();
  22986. + return answerNative(getNativeHandle(), context, question);
  22987. + }
  22988. - private BertQuestionAnswerer(long nativeHandle) {
  22989. - super(nativeHandle);
  22990. - }
  22991. + private BertQuestionAnswerer(long nativeHandle) {
  22992. + super(nativeHandle);
  22993. + }
  22994. - // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
  22995. - private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);
  22996. + // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
  22997. + private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);
  22998. - // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is sentencepiece model file
  22999. - // buffer.
  23000. - private static native long initJniWithAlbertByteBuffers(ByteBuffer... modelBuffers);
  23001. + // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is sentencepiece model file
  23002. + // buffer.
  23003. + private static native long initJniWithAlbertByteBuffers(ByteBuffer... modelBuffers);
  23004. - private static native long initJniWithFileDescriptor(
  23005. - int fileDescriptor,
  23006. - long fileDescriptorLength,
  23007. - long fileDescriptorOffset,
  23008. - long baseOptionsHandle);
  23009. + private static native long initJniWithFileDescriptor(int fileDescriptor,
  23010. + long fileDescriptorLength, long fileDescriptorOffset, long baseOptionsHandle);
  23011. - private static native List<QaAnswer> answerNative(
  23012. - long nativeHandle, String context, String question);
  23013. + private static native List<QaAnswer> answerNative(
  23014. + long nativeHandle, String context, String question);
  23015. - @Override
  23016. - protected void deinit(long nativeHandle) {
  23017. - deinitJni(nativeHandle);
  23018. - }
  23019. + @Override
  23020. + protected void deinit(long nativeHandle) {
  23021. + deinitJni(nativeHandle);
  23022. + }
  23023. - /**
  23024. - * Native implementation to release memory pointed by the pointer.
  23025. - *
  23026. - * @param nativeHandle pointer to memory allocated
  23027. - */
  23028. - private native void deinitJni(long nativeHandle);
  23029. + /**
  23030. + * Native implementation to release memory pointed by the pointer.
  23031. + *
  23032. + * @param nativeHandle pointer to memory allocated
  23033. + */
  23034. + private native void deinitJni(long nativeHandle);
  23035. }
  23036. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java
  23037. index b75a07e10cc7b..50917c035a995 100644
  23038. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java
  23039. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java
  23040. @@ -22,37 +22,37 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  23041. * position information to the context.
  23042. */
  23043. public class QaAnswer {
  23044. - public Pos pos;
  23045. - public String text;
  23046. -
  23047. - @UsedByReflection("bert_question_answerer_jni.cc")
  23048. - public QaAnswer(String text, Pos pos) {
  23049. - this.text = text;
  23050. - this.pos = pos;
  23051. - }
  23052. -
  23053. - public QaAnswer(String text, int start, int end, float logit) {
  23054. - this(text, new Pos(start, end, logit));
  23055. - }
  23056. -
  23057. - /**
  23058. - * Position information of the answer relative to context. It is sortable in descending order
  23059. - * based on logit.
  23060. - */
  23061. - public static class Pos implements Comparable<Pos> {
  23062. - public int start;
  23063. - public int end;
  23064. - public float logit;
  23065. -
  23066. - public Pos(int start, int end, float logit) {
  23067. - this.start = start;
  23068. - this.end = end;
  23069. - this.logit = logit;
  23070. + public Pos pos;
  23071. + public String text;
  23072. +
  23073. + @UsedByReflection("bert_question_answerer_jni.cc")
  23074. + public QaAnswer(String text, Pos pos) {
  23075. + this.text = text;
  23076. + this.pos = pos;
  23077. + }
  23078. +
  23079. + public QaAnswer(String text, int start, int end, float logit) {
  23080. + this(text, new Pos(start, end, logit));
  23081. }
  23082. - @Override
  23083. - public int compareTo(Pos other) {
  23084. - return Float.compare(other.logit, this.logit);
  23085. + /**
  23086. + * Position information of the answer relative to context. It is sortable in descending order
  23087. + * based on logit.
  23088. + */
  23089. + public static class Pos implements Comparable<Pos> {
  23090. + public int start;
  23091. + public int end;
  23092. + public float logit;
  23093. +
  23094. + public Pos(int start, int end, float logit) {
  23095. + this.start = start;
  23096. + this.end = end;
  23097. + this.logit = logit;
  23098. + }
  23099. +
  23100. + @Override
  23101. + public int compareTo(Pos other) {
  23102. + return Float.compare(other.logit, this.logit);
  23103. + }
  23104. }
  23105. - }
  23106. }
  23107. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java
  23108. index 8df6d3794e1b5..7a59a99d7fddf 100644
  23109. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java
  23110. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java
  23111. @@ -19,14 +19,13 @@ import java.util.List;
  23112. /** API to answer questions based on context. */
  23113. public interface QuestionAnswerer {
  23114. -
  23115. - /**
  23116. - * Answers question based on context, and returns a list of possible {@link QaAnswer}s. Could be
  23117. - * empty if no answer was found from the given context.
  23118. - *
  23119. - * @param context context the question bases on
  23120. - * @param question question to ask
  23121. - * @return a list of possible answers in {@link QaAnswer}
  23122. - */
  23123. - List<QaAnswer> answer(String context, String question);
  23124. + /**
  23125. + * Answers question based on context, and returns a list of possible {@link QaAnswer}s. Could be
  23126. + * empty if no answer was found from the given context.
  23127. + *
  23128. + * @param context context the question bases on
  23129. + * @param question question to ask
  23130. + * @return a list of possible answers in {@link QaAnswer}
  23131. + */
  23132. + List<QaAnswer> answer(String context, String question);
  23133. }
  23134. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/TextSearcher.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/TextSearcher.java
  23135. index 1a32d10e47114..ea3b1b8c25b34 100644
  23136. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/TextSearcher.java
  23137. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/TextSearcher.java
  23138. @@ -18,12 +18,9 @@ package org.tensorflow.lite.task.text.searcher;
  23139. import android.content.Context;
  23140. import android.content.res.AssetFileDescriptor;
  23141. import android.os.ParcelFileDescriptor;
  23142. +
  23143. import com.google.auto.value.AutoValue;
  23144. -import java.io.File;
  23145. -import java.io.IOException;
  23146. -import java.nio.ByteBuffer;
  23147. -import java.nio.MappedByteBuffer;
  23148. -import java.util.List;
  23149. +
  23150. import org.tensorflow.lite.task.core.BaseOptions;
  23151. import org.tensorflow.lite.task.core.BaseTaskApi;
  23152. import org.tensorflow.lite.task.core.TaskJniUtils;
  23153. @@ -31,6 +28,12 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
  23154. import org.tensorflow.lite.task.processor.NearestNeighbor;
  23155. import org.tensorflow.lite.task.processor.SearcherOptions;
  23156. +import java.io.File;
  23157. +import java.io.IOException;
  23158. +import java.nio.ByteBuffer;
  23159. +import java.nio.MappedByteBuffer;
  23160. +import java.util.List;
  23161. +
  23162. /**
  23163. * Performs similarity search on text string.
  23164. *
  23165. @@ -67,227 +70,193 @@ import org.tensorflow.lite.task.processor.SearcherOptions;
  23166. * the single file format (index file packed in the model) is supported.
  23167. */
  23168. public final class TextSearcher extends BaseTaskApi {
  23169. + private static final String TEXT_SEARCHER_NATIVE_LIB = "task_text_jni";
  23170. + private static final int OPTIONAL_FD_LENGTH = -1;
  23171. + private static final int OPTIONAL_FD_OFFSET = -1;
  23172. - private static final String TEXT_SEARCHER_NATIVE_LIB = "task_text_jni";
  23173. - private static final int OPTIONAL_FD_LENGTH = -1;
  23174. - private static final int OPTIONAL_FD_OFFSET = -1;
  23175. + /**
  23176. + * Creates an {@link TextSearcher} instance from {@link TextSearcherOptions}.
  23177. + *
  23178. + * @param modelPath path of the search model with metadata in the assets
  23179. + * @throws IOException if an I/O error occurs when loading the tflite model or the index file
  23180. + * @throws IllegalArgumentException if an argument is invalid
  23181. + * @throws IllegalStateException if there is an internal error
  23182. + * @throws RuntimeException if there is an otherwise unspecified error
  23183. + */
  23184. + public static TextSearcher createFromFileAndOptions(Context context, String modelPath,
  23185. + final TextSearcherOptions options) throws IOException {
  23186. + try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
  23187. + return createFromModelFdAndOptions(
  23188. + /*modelDescriptor=*/assetFileDescriptor.getParcelFileDescriptor().getFd(),
  23189. + /*modelDescriptorLength=*/assetFileDescriptor.getLength(),
  23190. + /*modelDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options);
  23191. + }
  23192. + }
  23193. - /**
  23194. - * Creates an {@link TextSearcher} instance from {@link TextSearcherOptions}.
  23195. - *
  23196. - * @param modelPath path of the search model with metadata in the assets
  23197. - * @throws IOException if an I/O error occurs when loading the tflite model or the index file
  23198. - * @throws IllegalArgumentException if an argument is invalid
  23199. - * @throws IllegalStateException if there is an internal error
  23200. - * @throws RuntimeException if there is an otherwise unspecified error
  23201. - */
  23202. - public static TextSearcher createFromFileAndOptions(
  23203. - Context context, String modelPath, final TextSearcherOptions options) throws IOException {
  23204. - try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
  23205. - return createFromModelFdAndOptions(
  23206. - /*modelDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
  23207. - /*modelDescriptorLength=*/ assetFileDescriptor.getLength(),
  23208. - /*modelDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
  23209. - options);
  23210. + /**
  23211. + * Creates an {@link TextSearcher} instance.
  23212. + *
  23213. + * @param modelFile the search model {@link File} instance
  23214. + * @throws IOException if an I/O error occurs when loading the tflite model or the index file
  23215. + * @throws IllegalArgumentException if an argument is invalid
  23216. + * @throws IllegalStateException if there is an internal error
  23217. + * @throws RuntimeException if there is an otherwise unspecified error
  23218. + */
  23219. + public static TextSearcher createFromFileAndOptions(
  23220. + File modelFile, final TextSearcherOptions options) throws IOException {
  23221. + try (ParcelFileDescriptor descriptor =
  23222. + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  23223. + return createFromModelFdAndOptions(
  23224. + /*modelDescriptor=*/descriptor.getFd(),
  23225. + /*modelDescriptorLength=*/OPTIONAL_FD_LENGTH,
  23226. + /*modelDescriptorOffset=*/OPTIONAL_FD_OFFSET, options);
  23227. + }
  23228. }
  23229. - }
  23230. - /**
  23231. - * Creates an {@link TextSearcher} instance.
  23232. - *
  23233. - * @param modelFile the search model {@link File} instance
  23234. - * @throws IOException if an I/O error occurs when loading the tflite model or the index file
  23235. - * @throws IllegalArgumentException if an argument is invalid
  23236. - * @throws IllegalStateException if there is an internal error
  23237. - * @throws RuntimeException if there is an otherwise unspecified error
  23238. - */
  23239. - public static TextSearcher createFromFileAndOptions(
  23240. - File modelFile, final TextSearcherOptions options) throws IOException {
  23241. - try (ParcelFileDescriptor descriptor =
  23242. - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  23243. - return createFromModelFdAndOptions(
  23244. - /*modelDescriptor=*/ descriptor.getFd(),
  23245. - /*modelDescriptorLength=*/ OPTIONAL_FD_LENGTH,
  23246. - /*modelDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
  23247. - options);
  23248. + /**
  23249. + * Creates an {@link TextSearcher} instance with a model buffer and {@link TextSearcherOptions}.
  23250. + *
  23251. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the search
  23252. + * model
  23253. + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  23254. + * {@link MappedByteBuffer}
  23255. + * @throws IOException if an I/O error occurs when loading the index file
  23256. + * @throws IllegalStateException if there is an internal error
  23257. + * @throws RuntimeException if there is an otherwise unspecified error
  23258. + */
  23259. + public static TextSearcher createFromBufferAndOptions(
  23260. + final ByteBuffer modelBuffer, final TextSearcherOptions options) throws IOException {
  23261. + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
  23262. + throw new IllegalArgumentException(
  23263. + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
  23264. + }
  23265. + if (options.getSearcherOptions().getIndexFile() != null) {
  23266. + try (ParcelFileDescriptor indexDescriptor =
  23267. + ParcelFileDescriptor.open(options.getSearcherOptions().getIndexFile(),
  23268. + ParcelFileDescriptor.MODE_READ_ONLY)) {
  23269. + return createFromBufferAndOptionsImpl(
  23270. + modelBuffer, options, indexDescriptor.getFd());
  23271. + }
  23272. + } else {
  23273. + return createFromBufferAndOptionsImpl(modelBuffer, options, /*indexFd=*/0);
  23274. + }
  23275. }
  23276. - }
  23277. - /**
  23278. - * Creates an {@link TextSearcher} instance with a model buffer and {@link TextSearcherOptions}.
  23279. - *
  23280. - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the search
  23281. - * model
  23282. - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  23283. - * {@link MappedByteBuffer}
  23284. - * @throws IOException if an I/O error occurs when loading the index file
  23285. - * @throws IllegalStateException if there is an internal error
  23286. - * @throws RuntimeException if there is an otherwise unspecified error
  23287. - */
  23288. - public static TextSearcher createFromBufferAndOptions(
  23289. - final ByteBuffer modelBuffer, final TextSearcherOptions options) throws IOException {
  23290. - if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
  23291. - throw new IllegalArgumentException(
  23292. - "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
  23293. + public static TextSearcher createFromBufferAndOptionsImpl(
  23294. + final ByteBuffer modelBuffer, final TextSearcherOptions options, final int indexFd) {
  23295. + return new TextSearcher(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
  23296. + @Override
  23297. + public long createHandle() {
  23298. + return initJniWithByteBuffer(modelBuffer,
  23299. + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()),
  23300. + options.getSearcherOptions().getL2Normalize(),
  23301. + options.getSearcherOptions().getQuantize(), indexFd,
  23302. + options.getSearcherOptions().getMaxResults());
  23303. + }
  23304. + }, TEXT_SEARCHER_NATIVE_LIB));
  23305. }
  23306. - if (options.getSearcherOptions().getIndexFile() != null) {
  23307. - try (ParcelFileDescriptor indexDescriptor =
  23308. - ParcelFileDescriptor.open(
  23309. - options.getSearcherOptions().getIndexFile(), ParcelFileDescriptor.MODE_READ_ONLY)) {
  23310. - return createFromBufferAndOptionsImpl(modelBuffer, options, indexDescriptor.getFd());
  23311. - }
  23312. - } else {
  23313. - return createFromBufferAndOptionsImpl(modelBuffer, options, /*indexFd=*/ 0);
  23314. +
  23315. + /**
  23316. + * Constructor to initialize the JNI with a pointer from C++.
  23317. + *
  23318. + * @param nativeHandle a pointer referencing memory allocated in C++
  23319. + */
  23320. + TextSearcher(long nativeHandle) {
  23321. + super(nativeHandle);
  23322. }
  23323. - }
  23324. - public static TextSearcher createFromBufferAndOptionsImpl(
  23325. - final ByteBuffer modelBuffer, final TextSearcherOptions options, final int indexFd) {
  23326. - return new TextSearcher(
  23327. - TaskJniUtils.createHandleFromLibrary(
  23328. - new EmptyHandleProvider() {
  23329. - @Override
  23330. - public long createHandle() {
  23331. - return initJniWithByteBuffer(
  23332. - modelBuffer,
  23333. - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()),
  23334. - options.getSearcherOptions().getL2Normalize(),
  23335. - options.getSearcherOptions().getQuantize(),
  23336. - indexFd,
  23337. - options.getSearcherOptions().getMaxResults());
  23338. - }
  23339. - },
  23340. - TEXT_SEARCHER_NATIVE_LIB));
  23341. - }
  23342. + /** Options for setting up an TextSearcher. */
  23343. + @AutoValue
  23344. + public abstract static class TextSearcherOptions {
  23345. + abstract BaseOptions getBaseOptions();
  23346. - /**
  23347. - * Constructor to initialize the JNI with a pointer from C++.
  23348. - *
  23349. - * @param nativeHandle a pointer referencing memory allocated in C++
  23350. - */
  23351. - TextSearcher(long nativeHandle) {
  23352. - super(nativeHandle);
  23353. - }
  23354. + abstract SearcherOptions getSearcherOptions();
  23355. - /** Options for setting up an TextSearcher. */
  23356. - @AutoValue
  23357. - public abstract static class TextSearcherOptions {
  23358. + public static Builder builder() {
  23359. + return new AutoValue_TextSearcher_TextSearcherOptions.Builder()
  23360. + .setBaseOptions(BaseOptions.builder().build())
  23361. + .setSearcherOptions(SearcherOptions.builder().build());
  23362. + }
  23363. - abstract BaseOptions getBaseOptions();
  23364. + /** Builder for {@link TextSearcherOptions}. */
  23365. + @AutoValue.Builder
  23366. + public abstract static class Builder {
  23367. + /** Sets the general options to configure Task APIs, such as accelerators. */
  23368. + public abstract Builder setBaseOptions(BaseOptions baseOptions);
  23369. - abstract SearcherOptions getSearcherOptions();
  23370. + /** Sets the options to configure Searcher API. */
  23371. + public abstract Builder setSearcherOptions(SearcherOptions searcherOptions);
  23372. - public static Builder builder() {
  23373. - return new AutoValue_TextSearcher_TextSearcherOptions.Builder()
  23374. - .setBaseOptions(BaseOptions.builder().build())
  23375. - .setSearcherOptions(SearcherOptions.builder().build());
  23376. + public abstract TextSearcherOptions build();
  23377. + }
  23378. }
  23379. - /** Builder for {@link TextSearcherOptions}. */
  23380. - @AutoValue.Builder
  23381. - public abstract static class Builder {
  23382. - /** Sets the general options to configure Task APIs, such as accelerators. */
  23383. - public abstract Builder setBaseOptions(BaseOptions baseOptions);
  23384. -
  23385. - /** Sets the options to configure Searcher API. */
  23386. - public abstract Builder setSearcherOptions(SearcherOptions searcherOptions);
  23387. -
  23388. - public abstract TextSearcherOptions build();
  23389. + /**
  23390. + * Performs embedding extraction on the provided string input, followed by nearest-neighbor
  23391. + * search in the index.
  23392. + *
  23393. + * @param text input text query to the model
  23394. + */
  23395. + public List<NearestNeighbor> search(String text) {
  23396. + return searchNative(getNativeHandle(), text);
  23397. }
  23398. - }
  23399. -
  23400. - /**
  23401. - * Performs embedding extraction on the provided string input, followed by nearest-neighbor search
  23402. - * in the index.
  23403. - *
  23404. - * @param text input text query to the model
  23405. - */
  23406. - public List<NearestNeighbor> search(String text) {
  23407. - return searchNative(getNativeHandle(), text);
  23408. - }
  23409. - private static TextSearcher createFromModelFdAndOptions(
  23410. - final int modelDescriptor,
  23411. - final long modelDescriptorLength,
  23412. - final long modelDescriptorOffset,
  23413. - final TextSearcherOptions options)
  23414. - throws IOException {
  23415. - if (options.getSearcherOptions().getIndexFile() != null) {
  23416. - // indexDescriptor must be alive before TextSearcher is initialized completely in the native
  23417. - // layer.
  23418. - try (ParcelFileDescriptor indexDescriptor =
  23419. - ParcelFileDescriptor.open(
  23420. - options.getSearcherOptions().getIndexFile(), ParcelFileDescriptor.MODE_READ_ONLY)) {
  23421. - return createFromModelFdAndOptionsImpl(
  23422. - modelDescriptor,
  23423. - modelDescriptorLength,
  23424. - modelDescriptorOffset,
  23425. - options,
  23426. - indexDescriptor.getFd());
  23427. - }
  23428. - } else {
  23429. - // Index file is not configured. We'll check if the model contains one in the native layer.
  23430. - return createFromModelFdAndOptionsImpl(
  23431. - modelDescriptor, modelDescriptorLength, modelDescriptorOffset, options, /*indexFd=*/ 0);
  23432. + private static TextSearcher createFromModelFdAndOptions(final int modelDescriptor,
  23433. + final long modelDescriptorLength, final long modelDescriptorOffset,
  23434. + final TextSearcherOptions options) throws IOException {
  23435. + if (options.getSearcherOptions().getIndexFile() != null) {
  23436. + // indexDescriptor must be alive before TextSearcher is initialized completely in the
  23437. + // native layer.
  23438. + try (ParcelFileDescriptor indexDescriptor =
  23439. + ParcelFileDescriptor.open(options.getSearcherOptions().getIndexFile(),
  23440. + ParcelFileDescriptor.MODE_READ_ONLY)) {
  23441. + return createFromModelFdAndOptionsImpl(modelDescriptor, modelDescriptorLength,
  23442. + modelDescriptorOffset, options, indexDescriptor.getFd());
  23443. + }
  23444. + } else {
  23445. + // Index file is not configured. We'll check if the model contains one in the native
  23446. + // layer.
  23447. + return createFromModelFdAndOptionsImpl(modelDescriptor, modelDescriptorLength,
  23448. + modelDescriptorOffset, options, /*indexFd=*/0);
  23449. + }
  23450. }
  23451. - }
  23452. - private static TextSearcher createFromModelFdAndOptionsImpl(
  23453. - final int modelDescriptor,
  23454. - final long modelDescriptorLength,
  23455. - final long modelDescriptorOffset,
  23456. - final TextSearcherOptions options,
  23457. - final int indexFd) {
  23458. - long nativeHandle =
  23459. - TaskJniUtils.createHandleFromLibrary(
  23460. - new EmptyHandleProvider() {
  23461. - @Override
  23462. - public long createHandle() {
  23463. - return initJniWithModelFdAndOptions(
  23464. - modelDescriptor,
  23465. - modelDescriptorLength,
  23466. - modelDescriptorOffset,
  23467. - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()),
  23468. - options.getSearcherOptions().getL2Normalize(),
  23469. - options.getSearcherOptions().getQuantize(),
  23470. - indexFd,
  23471. - options.getSearcherOptions().getMaxResults());
  23472. - }
  23473. - },
  23474. - TEXT_SEARCHER_NATIVE_LIB);
  23475. - return new TextSearcher(nativeHandle);
  23476. - }
  23477. + private static TextSearcher createFromModelFdAndOptionsImpl(final int modelDescriptor,
  23478. + final long modelDescriptorLength, final long modelDescriptorOffset,
  23479. + final TextSearcherOptions options, final int indexFd) {
  23480. + long nativeHandle = TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
  23481. + @Override
  23482. + public long createHandle() {
  23483. + return initJniWithModelFdAndOptions(modelDescriptor, modelDescriptorLength,
  23484. + modelDescriptorOffset,
  23485. + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()),
  23486. + options.getSearcherOptions().getL2Normalize(),
  23487. + options.getSearcherOptions().getQuantize(), indexFd,
  23488. + options.getSearcherOptions().getMaxResults());
  23489. + }
  23490. + }, TEXT_SEARCHER_NATIVE_LIB);
  23491. + return new TextSearcher(nativeHandle);
  23492. + }
  23493. - private static native long initJniWithModelFdAndOptions(
  23494. - int modelDescriptor,
  23495. - long modelDescriptorLength,
  23496. - long modelDescriptorOffset,
  23497. - long baseOptionsHandle,
  23498. - boolean l2Normalize,
  23499. - boolean quantize,
  23500. - int indexDescriptor,
  23501. - int maxResults);
  23502. + private static native long initJniWithModelFdAndOptions(int modelDescriptor,
  23503. + long modelDescriptorLength, long modelDescriptorOffset, long baseOptionsHandle,
  23504. + boolean l2Normalize, boolean quantize, int indexDescriptor, int maxResults);
  23505. - private static native long initJniWithByteBuffer(
  23506. - ByteBuffer modelBuffer,
  23507. - long baseOptionsHandle,
  23508. - boolean l2Normalize,
  23509. - boolean quantize,
  23510. - int indexFileDescriptor,
  23511. - int maxResults);
  23512. + private static native long initJniWithByteBuffer(ByteBuffer modelBuffer, long baseOptionsHandle,
  23513. + boolean l2Normalize, boolean quantize, int indexFileDescriptor, int maxResults);
  23514. - /** The native method to search an input text string. */
  23515. - private static native List<NearestNeighbor> searchNative(long nativeHandle, String text);
  23516. + /** The native method to search an input text string. */
  23517. + private static native List<NearestNeighbor> searchNative(long nativeHandle, String text);
  23518. - @Override
  23519. - protected void deinit(long nativeHandle) {
  23520. - deinitJni(nativeHandle);
  23521. - }
  23522. + @Override
  23523. + protected void deinit(long nativeHandle) {
  23524. + deinitJni(nativeHandle);
  23525. + }
  23526. - /**
  23527. - * Native implementation to release memory pointed by the pointer.
  23528. - *
  23529. - * @param nativeHandle pointer to memory allocated
  23530. - */
  23531. - private native void deinitJni(long nativeHandle);
  23532. + /**
  23533. + * Native implementation to release memory pointed by the pointer.
  23534. + *
  23535. + * @param nativeHandle pointer to memory allocated
  23536. + */
  23537. + private native void deinitJni(long nativeHandle);
  23538. }
  23539. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java
  23540. index 88aeecc8d62ca..e59a2e89e86f4 100644
  23541. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java
  23542. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java
  23543. @@ -16,11 +16,13 @@ limitations under the License.
  23544. package org.tensorflow.lite.task.vision.classifier;
  23545. import com.google.auto.value.AutoValue;
  23546. +
  23547. +import org.tensorflow.lite.support.label.Category;
  23548. +import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  23549. +
  23550. import java.util.ArrayList;
  23551. import java.util.Collections;
  23552. import java.util.List;
  23553. -import org.tensorflow.lite.support.label.Category;
  23554. -import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  23555. /**
  23556. * The classification results of one head in a multihead (a.k.a. multi-output) {@link
  23557. @@ -31,16 +33,15 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  23558. @AutoValue
  23559. @UsedByReflection("image_classifier_jni.cc")
  23560. public abstract class Classifications {
  23561. + @UsedByReflection("image_classifier_jni.cc")
  23562. + static Classifications create(List<Category> categories, int headIndex) {
  23563. + return new AutoValue_Classifications(
  23564. + Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex);
  23565. + }
  23566. - @UsedByReflection("image_classifier_jni.cc")
  23567. - static Classifications create(List<Category> categories, int headIndex) {
  23568. - return new AutoValue_Classifications(
  23569. - Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex);
  23570. - }
  23571. -
  23572. - // Same reason for not using ImmutableList as stated in
  23573. - // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}.
  23574. - public abstract List<Category> getCategories();
  23575. + // Same reason for not using ImmutableList as stated in
  23576. + // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}.
  23577. + public abstract List<Category> getCategories();
  23578. - public abstract int getHeadIndex();
  23579. + public abstract int getHeadIndex();
  23580. }
  23581. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java
  23582. index 90628928198d5..5b5be73bcca1e 100644
  23583. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java
  23584. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java
  23585. @@ -18,14 +18,9 @@ package org.tensorflow.lite.task.vision.classifier;
  23586. import android.content.Context;
  23587. import android.graphics.Rect;
  23588. import android.os.ParcelFileDescriptor;
  23589. +
  23590. import com.google.android.odml.image.MlImage;
  23591. -import java.io.File;
  23592. -import java.io.IOException;
  23593. -import java.nio.ByteBuffer;
  23594. -import java.nio.MappedByteBuffer;
  23595. -import java.util.ArrayList;
  23596. -import java.util.Collections;
  23597. -import java.util.List;
  23598. +
  23599. import org.tensorflow.lite.support.image.MlImageAdapter;
  23600. import org.tensorflow.lite.support.image.TensorImage;
  23601. import org.tensorflow.lite.task.core.BaseOptions;
  23602. @@ -37,6 +32,14 @@ import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
  23603. import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
  23604. import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
  23605. +import java.io.File;
  23606. +import java.io.IOException;
  23607. +import java.nio.ByteBuffer;
  23608. +import java.nio.MappedByteBuffer;
  23609. +import java.util.ArrayList;
  23610. +import java.util.Collections;
  23611. +import java.util.List;
  23612. +
  23613. /**
  23614. * Performs classification on images.
  23615. *
  23616. @@ -71,476 +74,449 @@ import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
  23617. * Hub.</a>.
  23618. */
  23619. public final class ImageClassifier extends BaseVisionTaskApi {
  23620. + private static final String IMAGE_CLASSIFIER_NATIVE_LIB = "task_vision_jni";
  23621. + private static final int OPTIONAL_FD_LENGTH = -1;
  23622. + private static final int OPTIONAL_FD_OFFSET = -1;
  23623. +
  23624. + /**
  23625. + * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}.
  23626. + *
  23627. + * @param modelPath path of the classification model with metadata in the assets
  23628. + * @throws IOException if an I/O error occurs when loading the tflite model
  23629. + * @throws IllegalArgumentException if an argument is invalid
  23630. + * @throws IllegalStateException if there is an internal error
  23631. + * @throws RuntimeException if there is an otherwise unspecified error
  23632. + */
  23633. + public static ImageClassifier createFromFile(Context context, String modelPath)
  23634. + throws IOException {
  23635. + return createFromFileAndOptions(
  23636. + context, modelPath, ImageClassifierOptions.builder().build());
  23637. + }
  23638. - private static final String IMAGE_CLASSIFIER_NATIVE_LIB = "task_vision_jni";
  23639. - private static final int OPTIONAL_FD_LENGTH = -1;
  23640. - private static final int OPTIONAL_FD_OFFSET = -1;
  23641. -
  23642. - /**
  23643. - * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}.
  23644. - *
  23645. - * @param modelPath path of the classification model with metadata in the assets
  23646. - * @throws IOException if an I/O error occurs when loading the tflite model
  23647. - * @throws IllegalArgumentException if an argument is invalid
  23648. - * @throws IllegalStateException if there is an internal error
  23649. - * @throws RuntimeException if there is an otherwise unspecified error
  23650. - */
  23651. - public static ImageClassifier createFromFile(Context context, String modelPath)
  23652. - throws IOException {
  23653. - return createFromFileAndOptions(context, modelPath, ImageClassifierOptions.builder().build());
  23654. - }
  23655. -
  23656. - /**
  23657. - * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}.
  23658. - *
  23659. - * @param modelFile the classification model {@link File} instance
  23660. - * @throws IOException if an I/O error occurs when loading the tflite model
  23661. - * @throws IllegalArgumentException if an argument is invalid
  23662. - * @throws IllegalStateException if there is an internal error
  23663. - * @throws RuntimeException if there is an otherwise unspecified error
  23664. - */
  23665. - public static ImageClassifier createFromFile(File modelFile) throws IOException {
  23666. - return createFromFileAndOptions(modelFile, ImageClassifierOptions.builder().build());
  23667. - }
  23668. -
  23669. - /**
  23670. - * Creates an {@link ImageClassifier} instance with a model buffer and the default {@link
  23671. - * ImageClassifierOptions}.
  23672. - *
  23673. - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
  23674. - * classification model
  23675. - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  23676. - * {@link MappedByteBuffer}
  23677. - * @throws IllegalStateException if there is an internal error
  23678. - * @throws RuntimeException if there is an otherwise unspecified error
  23679. - */
  23680. - public static ImageClassifier createFromBuffer(final ByteBuffer modelBuffer) {
  23681. - return createFromBufferAndOptions(modelBuffer, ImageClassifierOptions.builder().build());
  23682. - }
  23683. -
  23684. - /**
  23685. - * Creates an {@link ImageClassifier} instance from {@link ImageClassifierOptions}.
  23686. - *
  23687. - * @param modelPath path of the classification model with metadata in the assets
  23688. - * @throws IOException if an I/O error occurs when loading the tflite model
  23689. - * @throws IllegalArgumentException if an argument is invalid
  23690. - * @throws IllegalStateException if there is an internal error
  23691. - * @throws RuntimeException if there is an otherwise unspecified error
  23692. - */
  23693. - public static ImageClassifier createFromFileAndOptions(
  23694. - Context context, String modelPath, ImageClassifierOptions options) throws IOException {
  23695. - return new ImageClassifier(
  23696. - TaskJniUtils.createHandleFromFdAndOptions(
  23697. - context,
  23698. - new FdAndOptionsHandleProvider<ImageClassifierOptions>() {
  23699. - @Override
  23700. - public long createHandle(
  23701. - int fileDescriptor,
  23702. - long fileDescriptorLength,
  23703. - long fileDescriptorOffset,
  23704. - ImageClassifierOptions options) {
  23705. - return initJniWithModelFdAndOptions(
  23706. - fileDescriptor,
  23707. - fileDescriptorLength,
  23708. - fileDescriptorOffset,
  23709. - options,
  23710. - TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
  23711. - options.getBaseOptions(), options.getNumThreads()));
  23712. - }
  23713. - },
  23714. - IMAGE_CLASSIFIER_NATIVE_LIB,
  23715. - modelPath,
  23716. - options));
  23717. - }
  23718. -
  23719. - /**
  23720. - * Creates an {@link ImageClassifier} instance.
  23721. - *
  23722. - * @param modelFile the classification model {@link File} instance
  23723. - * @throws IOException if an I/O error occurs when loading the tflite model
  23724. - * @throws IllegalArgumentException if an argument is invalid
  23725. - * @throws IllegalStateException if there is an internal error
  23726. - * @throws RuntimeException if there is an otherwise unspecified error
  23727. - */
  23728. - public static ImageClassifier createFromFileAndOptions(
  23729. - File modelFile, final ImageClassifierOptions options) throws IOException {
  23730. - try (ParcelFileDescriptor descriptor =
  23731. - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  23732. - return new ImageClassifier(
  23733. - TaskJniUtils.createHandleFromLibrary(
  23734. - new TaskJniUtils.EmptyHandleProvider() {
  23735. - @Override
  23736. - public long createHandle() {
  23737. - return initJniWithModelFdAndOptions(
  23738. - descriptor.getFd(),
  23739. - /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
  23740. - /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
  23741. - options,
  23742. - TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
  23743. - options.getBaseOptions(), options.getNumThreads()));
  23744. - }
  23745. - },
  23746. - IMAGE_CLASSIFIER_NATIVE_LIB));
  23747. + /**
  23748. + * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}.
  23749. + *
  23750. + * @param modelFile the classification model {@link File} instance
  23751. + * @throws IOException if an I/O error occurs when loading the tflite model
  23752. + * @throws IllegalArgumentException if an argument is invalid
  23753. + * @throws IllegalStateException if there is an internal error
  23754. + * @throws RuntimeException if there is an otherwise unspecified error
  23755. + */
  23756. + public static ImageClassifier createFromFile(File modelFile) throws IOException {
  23757. + return createFromFileAndOptions(modelFile, ImageClassifierOptions.builder().build());
  23758. }
  23759. - }
  23760. -
  23761. - /**
  23762. - * Creates an {@link ImageClassifier} instance with a model buffer and {@link
  23763. - * ImageClassifierOptions}.
  23764. - *
  23765. - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
  23766. - * classification model
  23767. - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  23768. - * {@link MappedByteBuffer}
  23769. - * @throws IllegalStateException if there is an internal error
  23770. - * @throws RuntimeException if there is an otherwise unspecified error
  23771. - */
  23772. - public static ImageClassifier createFromBufferAndOptions(
  23773. - final ByteBuffer modelBuffer, final ImageClassifierOptions options) {
  23774. - if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
  23775. - throw new IllegalArgumentException(
  23776. - "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
  23777. +
  23778. + /**
  23779. + * Creates an {@link ImageClassifier} instance with a model buffer and the default {@link
  23780. + * ImageClassifierOptions}.
  23781. + *
  23782. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
  23783. + * classification model
  23784. + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  23785. + * {@link MappedByteBuffer}
  23786. + * @throws IllegalStateException if there is an internal error
  23787. + * @throws RuntimeException if there is an otherwise unspecified error
  23788. + */
  23789. + public static ImageClassifier createFromBuffer(final ByteBuffer modelBuffer) {
  23790. + return createFromBufferAndOptions(modelBuffer, ImageClassifierOptions.builder().build());
  23791. }
  23792. - return new ImageClassifier(
  23793. - TaskJniUtils.createHandleFromLibrary(
  23794. - new EmptyHandleProvider() {
  23795. - @Override
  23796. - public long createHandle() {
  23797. - return initJniWithByteBuffer(
  23798. - modelBuffer,
  23799. - options,
  23800. - TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
  23801. - options.getBaseOptions(), options.getNumThreads()));
  23802. - }
  23803. - },
  23804. - IMAGE_CLASSIFIER_NATIVE_LIB));
  23805. - }
  23806. -
  23807. - /**
  23808. - * Constructor to initialize the JNI with a pointer from C++.
  23809. - *
  23810. - * @param nativeHandle a pointer referencing memory allocated in C++
  23811. - */
  23812. - ImageClassifier(long nativeHandle) {
  23813. - super(nativeHandle);
  23814. - }
  23815. -
  23816. - /** Options for setting up an ImageClassifier. */
  23817. - @UsedByReflection("image_classifier_jni.cc")
  23818. - public static class ImageClassifierOptions {
  23819. - // Not using AutoValue for this class because scoreThreshold cannot have default value
  23820. - // (otherwise, the default value would override the one in the model metadata) and `Optional` is
  23821. - // not an option here, because
  23822. - // 1. java.util.Optional require Java 8 while we need to support Java 7.
  23823. - // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the
  23824. - // comments for labelAllowList.
  23825. - private final BaseOptions baseOptions;
  23826. - private final String displayNamesLocale;
  23827. - private final int maxResults;
  23828. - private final float scoreThreshold;
  23829. - private final boolean isScoreThresholdSet;
  23830. - // As an open source project, we've been trying avoiding depending on common java libraries,
  23831. - // such as Guava, because it may introduce conflicts with clients who also happen to use those
  23832. - // libraries. Therefore, instead of using ImmutableList here, we convert the List into
  23833. - // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
  23834. - // vulnerable.
  23835. - private final List<String> labelAllowList;
  23836. - private final List<String> labelDenyList;
  23837. - private final int numThreads;
  23838. -
  23839. - public static Builder builder() {
  23840. - return new Builder();
  23841. +
  23842. + /**
  23843. + * Creates an {@link ImageClassifier} instance from {@link ImageClassifierOptions}.
  23844. + *
  23845. + * @param modelPath path of the classification model with metadata in the assets
  23846. + * @throws IOException if an I/O error occurs when loading the tflite model
  23847. + * @throws IllegalArgumentException if an argument is invalid
  23848. + * @throws IllegalStateException if there is an internal error
  23849. + * @throws RuntimeException if there is an otherwise unspecified error
  23850. + */
  23851. + public static ImageClassifier createFromFileAndOptions(
  23852. + Context context, String modelPath, ImageClassifierOptions options) throws IOException {
  23853. + return new ImageClassifier(TaskJniUtils.createHandleFromFdAndOptions(
  23854. + context, new FdAndOptionsHandleProvider<ImageClassifierOptions>() {
  23855. + @Override
  23856. + public long createHandle(int fileDescriptor, long fileDescriptorLength,
  23857. + long fileDescriptorOffset, ImageClassifierOptions options) {
  23858. + return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength,
  23859. + fileDescriptorOffset, options,
  23860. + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
  23861. + options.getBaseOptions(), options.getNumThreads()));
  23862. + }
  23863. + }, IMAGE_CLASSIFIER_NATIVE_LIB, modelPath, options));
  23864. }
  23865. - /** A builder that helps to configure an instance of ImageClassifierOptions. */
  23866. - public static class Builder {
  23867. - private BaseOptions baseOptions = BaseOptions.builder().build();
  23868. - private String displayNamesLocale = "en";
  23869. - private int maxResults = -1;
  23870. - private float scoreThreshold;
  23871. - private boolean isScoreThresholdSet = false;
  23872. - private List<String> labelAllowList = new ArrayList<>();
  23873. - private List<String> labelDenyList = new ArrayList<>();
  23874. - private int numThreads = -1;
  23875. -
  23876. - Builder() {}
  23877. -
  23878. - /** Sets the general options to configure Task APIs, such as accelerators. */
  23879. - public Builder setBaseOptions(BaseOptions baseOptions) {
  23880. - this.baseOptions = baseOptions;
  23881. - return this;
  23882. - }
  23883. -
  23884. - /**
  23885. - * Sets the locale to use for display names specified through the TFLite Model Metadata, if
  23886. - * any.
  23887. - *
  23888. - * <p>Defaults to English({@code "en"}). See the <a
  23889. - * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
  23890. - * Metadata schema file.</a> for the accepted pattern of locale.
  23891. - */
  23892. - public Builder setDisplayNamesLocale(String displayNamesLocale) {
  23893. - this.displayNamesLocale = displayNamesLocale;
  23894. - return this;
  23895. - }
  23896. -
  23897. - /**
  23898. - * Sets the maximum number of top scored results to return.
  23899. - *
  23900. - * <p>If < 0, all results will be returned. If 0, an invalid argument error is returned.
  23901. - * Defaults to -1.
  23902. - *
  23903. - * @throws IllegalArgumentException if maxResults is 0.
  23904. - */
  23905. - public Builder setMaxResults(int maxResults) {
  23906. - if (maxResults == 0) {
  23907. - throw new IllegalArgumentException("maxResults cannot be 0.");
  23908. + /**
  23909. + * Creates an {@link ImageClassifier} instance.
  23910. + *
  23911. + * @param modelFile the classification model {@link File} instance
  23912. + * @throws IOException if an I/O error occurs when loading the tflite model
  23913. + * @throws IllegalArgumentException if an argument is invalid
  23914. + * @throws IllegalStateException if there is an internal error
  23915. + * @throws RuntimeException if there is an otherwise unspecified error
  23916. + */
  23917. + public static ImageClassifier createFromFileAndOptions(
  23918. + File modelFile, final ImageClassifierOptions options) throws IOException {
  23919. + try (ParcelFileDescriptor descriptor =
  23920. + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  23921. + return new ImageClassifier(
  23922. + TaskJniUtils.createHandleFromLibrary(new TaskJniUtils.EmptyHandleProvider() {
  23923. + @Override
  23924. + public long createHandle() {
  23925. + return initJniWithModelFdAndOptions(descriptor.getFd(),
  23926. + /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH,
  23927. + /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options,
  23928. + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
  23929. + options.getBaseOptions(), options.getNumThreads()));
  23930. + }
  23931. + }, IMAGE_CLASSIFIER_NATIVE_LIB));
  23932. }
  23933. - this.maxResults = maxResults;
  23934. - return this;
  23935. - }
  23936. -
  23937. - /**
  23938. - * Sets the score threshold.
  23939. - *
  23940. - * <p>It overrides the one provided in the model metadata (if any). Results below this value
  23941. - * are rejected.
  23942. - */
  23943. - public Builder setScoreThreshold(float scoreThreshold) {
  23944. - this.scoreThreshold = scoreThreshold;
  23945. - isScoreThresholdSet = true;
  23946. - return this;
  23947. - }
  23948. -
  23949. - /**
  23950. - * Sets the optional allowlist of labels.
  23951. - *
  23952. - * <p>If non-empty, classifications whose label is not in this set will be filtered out.
  23953. - * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList.
  23954. - */
  23955. - public Builder setLabelAllowList(List<String> labelAllowList) {
  23956. - this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
  23957. - return this;
  23958. - }
  23959. -
  23960. - /**
  23961. - * Sets the optional denylist of labels.
  23962. - *
  23963. - * <p>If non-empty, classifications whose label is in this set will be filtered out. Duplicate
  23964. - * or unknown labels are ignored. Mutually exclusive with labelAllowList.
  23965. - */
  23966. - public Builder setLabelDenyList(List<String> labelDenyList) {
  23967. - this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
  23968. - return this;
  23969. - }
  23970. -
  23971. - /**
  23972. - * Sets the number of threads to be used for TFLite ops that support multi-threading when
  23973. - * running inference with CPU. Defaults to -1.
  23974. - *
  23975. - * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
  23976. - * effect to let TFLite runtime set the value.
  23977. - *
  23978. - * @deprecated use {@link BaseOptions} to configure number of threads instead. This method
  23979. - * will override the number of threads configured from {@link BaseOptions}.
  23980. - */
  23981. - @Deprecated
  23982. - public Builder setNumThreads(int numThreads) {
  23983. - this.numThreads = numThreads;
  23984. - return this;
  23985. - }
  23986. -
  23987. - public ImageClassifierOptions build() {
  23988. - return new ImageClassifierOptions(this);
  23989. - }
  23990. }
  23991. - @UsedByReflection("image_classifier_jni.cc")
  23992. - public String getDisplayNamesLocale() {
  23993. - return displayNamesLocale;
  23994. + /**
  23995. + * Creates an {@link ImageClassifier} instance with a model buffer and {@link
  23996. + * ImageClassifierOptions}.
  23997. + *
  23998. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
  23999. + * classification model
  24000. + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  24001. + * {@link MappedByteBuffer}
  24002. + * @throws IllegalStateException if there is an internal error
  24003. + * @throws RuntimeException if there is an otherwise unspecified error
  24004. + */
  24005. + public static ImageClassifier createFromBufferAndOptions(
  24006. + final ByteBuffer modelBuffer, final ImageClassifierOptions options) {
  24007. + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
  24008. + throw new IllegalArgumentException(
  24009. + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
  24010. + }
  24011. + return new ImageClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
  24012. + @Override
  24013. + public long createHandle() {
  24014. + return initJniWithByteBuffer(modelBuffer, options,
  24015. + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
  24016. + options.getBaseOptions(), options.getNumThreads()));
  24017. + }
  24018. + }, IMAGE_CLASSIFIER_NATIVE_LIB));
  24019. }
  24020. - @UsedByReflection("image_classifier_jni.cc")
  24021. - public int getMaxResults() {
  24022. - return maxResults;
  24023. + /**
  24024. + * Constructor to initialize the JNI with a pointer from C++.
  24025. + *
  24026. + * @param nativeHandle a pointer referencing memory allocated in C++
  24027. + */
  24028. + ImageClassifier(long nativeHandle) {
  24029. + super(nativeHandle);
  24030. }
  24031. + /** Options for setting up an ImageClassifier. */
  24032. @UsedByReflection("image_classifier_jni.cc")
  24033. - public float getScoreThreshold() {
  24034. - return scoreThreshold;
  24035. + public static class ImageClassifierOptions {
  24036. + // Not using AutoValue for this class because scoreThreshold cannot have default value
  24037. + // (otherwise, the default value would override the one in the model metadata) and
  24038. + // `Optional` is not an option here, because
  24039. + // 1. java.util.Optional require Java 8 while we need to support Java 7.
  24040. + // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See
  24041. + // the comments for labelAllowList.
  24042. + private final BaseOptions baseOptions;
  24043. + private final String displayNamesLocale;
  24044. + private final int maxResults;
  24045. + private final float scoreThreshold;
  24046. + private final boolean isScoreThresholdSet;
  24047. + // As an open source project, we've been trying avoiding depending on common java libraries,
  24048. + // such as Guava, because it may introduce conflicts with clients who also happen to use
  24049. + // those libraries. Therefore, instead of using ImmutableList here, we convert the List into
  24050. + // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
  24051. + // vulnerable.
  24052. + private final List<String> labelAllowList;
  24053. + private final List<String> labelDenyList;
  24054. + private final int numThreads;
  24055. +
  24056. + public static Builder builder() {
  24057. + return new Builder();
  24058. + }
  24059. +
  24060. + /** A builder that helps to configure an instance of ImageClassifierOptions. */
  24061. + public static class Builder {
  24062. + private BaseOptions baseOptions = BaseOptions.builder().build();
  24063. + private String displayNamesLocale = "en";
  24064. + private int maxResults = -1;
  24065. + private float scoreThreshold;
  24066. + private boolean isScoreThresholdSet = false;
  24067. + private List<String> labelAllowList = new ArrayList<>();
  24068. + private List<String> labelDenyList = new ArrayList<>();
  24069. + private int numThreads = -1;
  24070. +
  24071. + Builder() {}
  24072. +
  24073. + /** Sets the general options to configure Task APIs, such as accelerators. */
  24074. + public Builder setBaseOptions(BaseOptions baseOptions) {
  24075. + this.baseOptions = baseOptions;
  24076. + return this;
  24077. + }
  24078. +
  24079. + /**
  24080. + * Sets the locale to use for display names specified through the TFLite Model Metadata,
  24081. + * if any.
  24082. + *
  24083. + * <p>Defaults to English({@code "en"}). See the <a
  24084. + * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
  24085. + * Metadata schema file.</a> for the accepted pattern of locale.
  24086. + */
  24087. + public Builder setDisplayNamesLocale(String displayNamesLocale) {
  24088. + this.displayNamesLocale = displayNamesLocale;
  24089. + return this;
  24090. + }
  24091. +
  24092. + /**
  24093. + * Sets the maximum number of top scored results to return.
  24094. + *
  24095. + * <p>If < 0, all results will be returned. If 0, an invalid argument error is returned.
  24096. + * Defaults to -1.
  24097. + *
  24098. + * @throws IllegalArgumentException if maxResults is 0.
  24099. + */
  24100. + public Builder setMaxResults(int maxResults) {
  24101. + if (maxResults == 0) {
  24102. + throw new IllegalArgumentException("maxResults cannot be 0.");
  24103. + }
  24104. + this.maxResults = maxResults;
  24105. + return this;
  24106. + }
  24107. +
  24108. + /**
  24109. + * Sets the score threshold.
  24110. + *
  24111. + * <p>It overrides the one provided in the model metadata (if any). Results below this
  24112. + * value are rejected.
  24113. + */
  24114. + public Builder setScoreThreshold(float scoreThreshold) {
  24115. + this.scoreThreshold = scoreThreshold;
  24116. + isScoreThresholdSet = true;
  24117. + return this;
  24118. + }
  24119. +
  24120. + /**
  24121. + * Sets the optional allowlist of labels.
  24122. + *
  24123. + * <p>If non-empty, classifications whose label is not in this set will be filtered out.
  24124. + * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList.
  24125. + */
  24126. + public Builder setLabelAllowList(List<String> labelAllowList) {
  24127. + this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
  24128. + return this;
  24129. + }
  24130. +
  24131. + /**
  24132. + * Sets the optional denylist of labels.
  24133. + *
  24134. + * <p>If non-empty, classifications whose label is in this set will be filtered out.
  24135. + * Duplicate or unknown labels are ignored. Mutually exclusive with labelAllowList.
  24136. + */
  24137. + public Builder setLabelDenyList(List<String> labelDenyList) {
  24138. + this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
  24139. + return this;
  24140. + }
  24141. +
  24142. + /**
  24143. + * Sets the number of threads to be used for TFLite ops that support multi-threading
  24144. + * when running inference with CPU. Defaults to -1.
  24145. + *
  24146. + * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has
  24147. + * the effect to let TFLite runtime set the value.
  24148. + *
  24149. + * @deprecated use {@link BaseOptions} to configure number of threads instead. This
  24150. + * method
  24151. + * will override the number of threads configured from {@link BaseOptions}.
  24152. + */
  24153. + @Deprecated
  24154. + public Builder setNumThreads(int numThreads) {
  24155. + this.numThreads = numThreads;
  24156. + return this;
  24157. + }
  24158. +
  24159. + public ImageClassifierOptions build() {
  24160. + return new ImageClassifierOptions(this);
  24161. + }
  24162. + }
  24163. +
  24164. + @UsedByReflection("image_classifier_jni.cc")
  24165. + public String getDisplayNamesLocale() {
  24166. + return displayNamesLocale;
  24167. + }
  24168. +
  24169. + @UsedByReflection("image_classifier_jni.cc")
  24170. + public int getMaxResults() {
  24171. + return maxResults;
  24172. + }
  24173. +
  24174. + @UsedByReflection("image_classifier_jni.cc")
  24175. + public float getScoreThreshold() {
  24176. + return scoreThreshold;
  24177. + }
  24178. +
  24179. + @UsedByReflection("image_classifier_jni.cc")
  24180. + public boolean getIsScoreThresholdSet() {
  24181. + return isScoreThresholdSet;
  24182. + }
  24183. +
  24184. + @UsedByReflection("image_classifier_jni.cc")
  24185. + public List<String> getLabelAllowList() {
  24186. + return new ArrayList<>(labelAllowList);
  24187. + }
  24188. +
  24189. + @UsedByReflection("image_classifier_jni.cc")
  24190. + public List<String> getLabelDenyList() {
  24191. + return new ArrayList<>(labelDenyList);
  24192. + }
  24193. +
  24194. + @UsedByReflection("image_classifier_jni.cc")
  24195. + public int getNumThreads() {
  24196. + return numThreads;
  24197. + }
  24198. +
  24199. + public BaseOptions getBaseOptions() {
  24200. + return baseOptions;
  24201. + }
  24202. +
  24203. + ImageClassifierOptions(Builder builder) {
  24204. + displayNamesLocale = builder.displayNamesLocale;
  24205. + maxResults = builder.maxResults;
  24206. + scoreThreshold = builder.scoreThreshold;
  24207. + isScoreThresholdSet = builder.isScoreThresholdSet;
  24208. + labelAllowList = builder.labelAllowList;
  24209. + labelDenyList = builder.labelDenyList;
  24210. + numThreads = builder.numThreads;
  24211. + baseOptions = builder.baseOptions;
  24212. + }
  24213. }
  24214. - @UsedByReflection("image_classifier_jni.cc")
  24215. - public boolean getIsScoreThresholdSet() {
  24216. - return isScoreThresholdSet;
  24217. + /**
  24218. + * Performs actual classification on the provided {@link TensorImage}.
  24219. + *
  24220. + * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types:
  24221. + *
  24222. + * <ul>
  24223. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
  24224. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
  24225. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
  24226. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
  24227. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
  24228. + * </ul>
  24229. + *
  24230. + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
  24231. + * @throws IllegalArgumentException if the color space type of image is unsupported
  24232. + */
  24233. + public List<Classifications> classify(TensorImage image) {
  24234. + return classify(image, ImageProcessingOptions.builder().build());
  24235. }
  24236. - @UsedByReflection("image_classifier_jni.cc")
  24237. - public List<String> getLabelAllowList() {
  24238. - return new ArrayList<>(labelAllowList);
  24239. + /**
  24240. + * Performs actual classification on the provided {@link TensorImage} with {@link
  24241. + * ImageProcessingOptions}.
  24242. + *
  24243. + * <p>{@link ImageClassifier} supports the following options:
  24244. + *
  24245. + * <ul>
  24246. + * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
  24247. + * defaults to the entire image.
  24248. + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
  24249. + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
  24250. + * </ul>
  24251. + *
  24252. + * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types:
  24253. + *
  24254. + * <ul>
  24255. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
  24256. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
  24257. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
  24258. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
  24259. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
  24260. + * </ul>
  24261. + *
  24262. + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
  24263. + * @throws IllegalArgumentException if the color space type of image is unsupported
  24264. + */
  24265. + public List<Classifications> classify(TensorImage image, ImageProcessingOptions options) {
  24266. + return run(new InferenceProvider<List<Classifications>>() {
  24267. + @Override
  24268. + public List<Classifications> run(
  24269. + long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
  24270. + return classify(frameBufferHandle, width, height, options);
  24271. + }
  24272. + }, image, options);
  24273. }
  24274. - @UsedByReflection("image_classifier_jni.cc")
  24275. - public List<String> getLabelDenyList() {
  24276. - return new ArrayList<>(labelDenyList);
  24277. + /**
  24278. + * Performs actual classification on the provided {@code MlImage}.
  24279. + *
  24280. + * @param image an {@code MlImage} object that represents an image
  24281. + * @throws IllegalArgumentException if the storage type or format of the image is unsupported
  24282. + */
  24283. + public List<Classifications> classify(MlImage image) {
  24284. + return classify(image, ImageProcessingOptions.builder().build());
  24285. }
  24286. - @UsedByReflection("image_classifier_jni.cc")
  24287. - public int getNumThreads() {
  24288. - return numThreads;
  24289. + /**
  24290. + * Performs actual classification on the provided {@code MlImage} with {@link
  24291. + * ImageProcessingOptions}.
  24292. + *
  24293. + * <p>{@link ImageClassifier} supports the following options:
  24294. + *
  24295. + * <ul>
  24296. + * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
  24297. + * defaults to the entire image.
  24298. + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
  24299. + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
  24300. + * MlImage#getRotation()} is not effective.
  24301. + * </ul>
  24302. + *
  24303. + * @param image a {@code MlImage} object that represents an image
  24304. + * @param options configures options including ROI and rotation
  24305. + * @throws IllegalArgumentException if the storage type or format of the image is unsupported
  24306. + */
  24307. + public List<Classifications> classify(MlImage image, ImageProcessingOptions options) {
  24308. + image.getInternal().acquire();
  24309. + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
  24310. + List<Classifications> result = classify(tensorImage, options);
  24311. + image.close();
  24312. + return result;
  24313. }
  24314. - public BaseOptions getBaseOptions() {
  24315. - return baseOptions;
  24316. + private List<Classifications> classify(
  24317. + long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
  24318. + checkNotClosed();
  24319. +
  24320. + Rect roi = options.getRoi().isEmpty() ? new Rect(0, 0, width, height) : options.getRoi();
  24321. +
  24322. + return classifyNative(getNativeHandle(), frameBufferHandle,
  24323. + new int[] {roi.left, roi.top, roi.width(), roi.height()});
  24324. }
  24325. - ImageClassifierOptions(Builder builder) {
  24326. - displayNamesLocale = builder.displayNamesLocale;
  24327. - maxResults = builder.maxResults;
  24328. - scoreThreshold = builder.scoreThreshold;
  24329. - isScoreThresholdSet = builder.isScoreThresholdSet;
  24330. - labelAllowList = builder.labelAllowList;
  24331. - labelDenyList = builder.labelDenyList;
  24332. - numThreads = builder.numThreads;
  24333. - baseOptions = builder.baseOptions;
  24334. + private static native long initJniWithModelFdAndOptions(int fileDescriptor,
  24335. + long fileDescriptorLength, long fileDescriptorOffset, ImageClassifierOptions options,
  24336. + long baseOptionsHandle);
  24337. +
  24338. + private static native long initJniWithByteBuffer(
  24339. + ByteBuffer modelBuffer, ImageClassifierOptions options, long baseOptionsHandle);
  24340. +
  24341. + /**
  24342. + * The native method to classify an image with the ROI and orientation.
  24343. + *
  24344. + * @param roi the ROI of the input image, an array representing the bounding box as {left, top,
  24345. + * width, height}
  24346. + */
  24347. + private static native List<Classifications> classifyNative(
  24348. + long nativeHandle, long frameBufferHandle, int[] roi);
  24349. +
  24350. + @Override
  24351. + protected void deinit(long nativeHandle) {
  24352. + deinitJni(nativeHandle);
  24353. }
  24354. - }
  24355. -
  24356. - /**
  24357. - * Performs actual classification on the provided {@link TensorImage}.
  24358. - *
  24359. - * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types:
  24360. - *
  24361. - * <ul>
  24362. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
  24363. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
  24364. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
  24365. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
  24366. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
  24367. - * </ul>
  24368. - *
  24369. - * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
  24370. - * @throws IllegalArgumentException if the color space type of image is unsupported
  24371. - */
  24372. - public List<Classifications> classify(TensorImage image) {
  24373. - return classify(image, ImageProcessingOptions.builder().build());
  24374. - }
  24375. -
  24376. - /**
  24377. - * Performs actual classification on the provided {@link TensorImage} with {@link
  24378. - * ImageProcessingOptions}.
  24379. - *
  24380. - * <p>{@link ImageClassifier} supports the following options:
  24381. - *
  24382. - * <ul>
  24383. - * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
  24384. - * defaults to the entire image.
  24385. - * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
  24386. - * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
  24387. - * </ul>
  24388. - *
  24389. - * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types:
  24390. - *
  24391. - * <ul>
  24392. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
  24393. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
  24394. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
  24395. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
  24396. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
  24397. - * </ul>
  24398. - *
  24399. - * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
  24400. - * @throws IllegalArgumentException if the color space type of image is unsupported
  24401. - */
  24402. - public List<Classifications> classify(TensorImage image, ImageProcessingOptions options) {
  24403. - return run(
  24404. - new InferenceProvider<List<Classifications>>() {
  24405. - @Override
  24406. - public List<Classifications> run(
  24407. - long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
  24408. - return classify(frameBufferHandle, width, height, options);
  24409. - }
  24410. - },
  24411. - image,
  24412. - options);
  24413. - }
  24414. -
  24415. - /**
  24416. - * Performs actual classification on the provided {@code MlImage}.
  24417. - *
  24418. - * @param image an {@code MlImage} object that represents an image
  24419. - * @throws IllegalArgumentException if the storage type or format of the image is unsupported
  24420. - */
  24421. - public List<Classifications> classify(MlImage image) {
  24422. - return classify(image, ImageProcessingOptions.builder().build());
  24423. - }
  24424. -
  24425. - /**
  24426. - * Performs actual classification on the provided {@code MlImage} with {@link
  24427. - * ImageProcessingOptions}.
  24428. - *
  24429. - * <p>{@link ImageClassifier} supports the following options:
  24430. - *
  24431. - * <ul>
  24432. - * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
  24433. - * defaults to the entire image.
  24434. - * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
  24435. - * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
  24436. - * MlImage#getRotation()} is not effective.
  24437. - * </ul>
  24438. - *
  24439. - * @param image a {@code MlImage} object that represents an image
  24440. - * @param options configures options including ROI and rotation
  24441. - * @throws IllegalArgumentException if the storage type or format of the image is unsupported
  24442. - */
  24443. - public List<Classifications> classify(MlImage image, ImageProcessingOptions options) {
  24444. - image.getInternal().acquire();
  24445. - TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
  24446. - List<Classifications> result = classify(tensorImage, options);
  24447. - image.close();
  24448. - return result;
  24449. - }
  24450. -
  24451. - private List<Classifications> classify(
  24452. - long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
  24453. - checkNotClosed();
  24454. -
  24455. - Rect roi = options.getRoi().isEmpty() ? new Rect(0, 0, width, height) : options.getRoi();
  24456. -
  24457. - return classifyNative(
  24458. - getNativeHandle(),
  24459. - frameBufferHandle,
  24460. - new int[] {roi.left, roi.top, roi.width(), roi.height()});
  24461. - }
  24462. -
  24463. - private static native long initJniWithModelFdAndOptions(
  24464. - int fileDescriptor,
  24465. - long fileDescriptorLength,
  24466. - long fileDescriptorOffset,
  24467. - ImageClassifierOptions options,
  24468. - long baseOptionsHandle);
  24469. -
  24470. - private static native long initJniWithByteBuffer(
  24471. - ByteBuffer modelBuffer, ImageClassifierOptions options, long baseOptionsHandle);
  24472. -
  24473. - /**
  24474. - * The native method to classify an image with the ROI and orientation.
  24475. - *
  24476. - * @param roi the ROI of the input image, an array representing the bounding box as {left, top,
  24477. - * width, height}
  24478. - */
  24479. - private static native List<Classifications> classifyNative(
  24480. - long nativeHandle, long frameBufferHandle, int[] roi);
  24481. -
  24482. - @Override
  24483. - protected void deinit(long nativeHandle) {
  24484. - deinitJni(nativeHandle);
  24485. - }
  24486. -
  24487. - /**
  24488. - * Native implementation to release memory pointed by the pointer.
  24489. - *
  24490. - * @param nativeHandle pointer to memory allocated
  24491. - */
  24492. - private native void deinitJni(long nativeHandle);
  24493. +
  24494. + /**
  24495. + * Native implementation to release memory pointed by the pointer.
  24496. + *
  24497. + * @param nativeHandle pointer to memory allocated
  24498. + */
  24499. + private native void deinitJni(long nativeHandle);
  24500. }
  24501. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java
  24502. index fdc898f451337..59ab62a949a25 100644
  24503. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java
  24504. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java
  24505. @@ -21,213 +21,184 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
  24506. import android.graphics.ImageFormat;
  24507. import android.media.Image;
  24508. import android.media.Image.Plane;
  24509. +
  24510. import com.google.auto.value.AutoValue;
  24511. -import java.nio.ByteBuffer;
  24512. +
  24513. import org.tensorflow.lite.DataType;
  24514. import org.tensorflow.lite.support.image.ColorSpaceType;
  24515. import org.tensorflow.lite.support.image.TensorImage;
  24516. import org.tensorflow.lite.task.core.BaseTaskApi;
  24517. import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
  24518. +import java.nio.ByteBuffer;
  24519. +
  24520. /** Base class for Task Vision APIs. */
  24521. public abstract class BaseVisionTaskApi extends BaseTaskApi {
  24522. -
  24523. - /** Syntax sugar to run vision tasks with FrameBuffer and image processing options. */
  24524. - public interface InferenceProvider<T> {
  24525. - T run(long frameBufferHandle, int width, int height, ImageProcessingOptions options);
  24526. - }
  24527. -
  24528. - protected BaseVisionTaskApi(long nativeHandle) {
  24529. - super(nativeHandle);
  24530. - }
  24531. -
  24532. - /** Runs inference with {@link TensorImage} and {@link ImageProcessingOptions}. */
  24533. - protected <T> T run(
  24534. - InferenceProvider<T> provider, TensorImage image, ImageProcessingOptions options) {
  24535. - FrameBufferData frameBufferData = createFrameBuffer(image, options.getOrientation().getValue());
  24536. - T results =
  24537. - provider.run(
  24538. - frameBufferData.getFrameBufferHandle(), image.getWidth(), image.getHeight(), options);
  24539. - deleteFrameBuffer(
  24540. - frameBufferData.getFrameBufferHandle(),
  24541. - frameBufferData.getByteArrayHandle(),
  24542. - frameBufferData.getByteArray());
  24543. - return results;
  24544. - }
  24545. -
  24546. - private static FrameBufferData createFrameBuffer(TensorImage image, int orientation) {
  24547. - ColorSpaceType colorSpaceType = image.getColorSpaceType();
  24548. - switch (colorSpaceType) {
  24549. - case RGB:
  24550. - case NV12:
  24551. - case NV21:
  24552. - case YV12:
  24553. - case YV21:
  24554. - // All these types can be converted to ByteBuffer inside TensorImage. Creating FrameBuffer
  24555. - // base on the image ByteBuffer.
  24556. - return createFrameBufferFromByteBuffer(image, orientation);
  24557. - case YUV_420_888:
  24558. - // YUV_420_888 is a specific type for android.media.Image.
  24559. - return createFrameBufferFromMediaImage(image, orientation);
  24560. - default:
  24561. - throw new IllegalArgumentException(
  24562. - "Color space type, " + colorSpaceType.name() + ", is unsupported.");
  24563. + /** Syntax sugar to run vision tasks with FrameBuffer and image processing options. */
  24564. + public interface InferenceProvider<T> {
  24565. + T run(long frameBufferHandle, int width, int height, ImageProcessingOptions options);
  24566. }
  24567. - }
  24568. -
  24569. - /**
  24570. - * Creates FrameBuffer from the {@link android.media.Image} stored in the given {@link
  24571. - * TensorImage}.
  24572. - */
  24573. - private static FrameBufferData createFrameBufferFromMediaImage(
  24574. - TensorImage image, int orientation) {
  24575. - Image mediaImage = image.getMediaImage();
  24576. -
  24577. - checkArgument(
  24578. - mediaImage.getFormat() == ImageFormat.YUV_420_888,
  24579. - "Only supports loading YUV_420_888 Image.");
  24580. -
  24581. - Plane[] planes = mediaImage.getPlanes();
  24582. - checkArgument(
  24583. - planes.length == 3,
  24584. - String.format("The input image should have 3 planes, but got %d plane(s).", planes.length));
  24585. -
  24586. - // Verify and rewind planes.
  24587. - for (Plane plane : planes) {
  24588. - ByteBuffer buffer = plane.getBuffer();
  24589. - checkNotNull(buffer, "The image buffer is corrupted and the plane is null.");
  24590. - // From the public documentation, plane.getBuffer() should always return a direct ByteBuffer.
  24591. - // See https://developer.android.com/reference/android/media/Image.Plane#getBuffer()
  24592. - checkArgument(
  24593. - buffer.isDirect(),
  24594. - "The image plane buffer is not a direct ByteBuffer, and is not supported.");
  24595. - buffer.rewind();
  24596. +
  24597. + protected BaseVisionTaskApi(long nativeHandle) {
  24598. + super(nativeHandle);
  24599. }
  24600. - return FrameBufferData.create(
  24601. - createFrameBufferFromPlanes(
  24602. - planes[0].getBuffer(),
  24603. - planes[1].getBuffer(),
  24604. - planes[2].getBuffer(),
  24605. - mediaImage.getWidth(),
  24606. - mediaImage.getHeight(),
  24607. - planes[0].getRowStride(),
  24608. - // row_stride and pixel_stride should be identical for U/V planes.
  24609. - planes[1].getRowStride(),
  24610. - planes[1].getPixelStride(),
  24611. - orientation),
  24612. - // FrameBuffer created with direct ByteBuffer does not require memory freeing.
  24613. - /*byteArrayHandle=*/ 0,
  24614. - /*byteArray=*/ new byte[0]);
  24615. - }
  24616. -
  24617. - /** Creates FrameBuffer from the {@link ByteBuffer} stored in the given {@link TensorImage}. */
  24618. - private static FrameBufferData createFrameBufferFromByteBuffer(
  24619. - TensorImage image, int orientation) {
  24620. - // base_vision_api_jni.cc expects an uint8 image. Convert image of other types into uint8.
  24621. - TensorImage imageUint8 =
  24622. - image.getDataType() == DataType.UINT8
  24623. - ? image
  24624. - : TensorImage.createFrom(image, DataType.UINT8);
  24625. -
  24626. - ByteBuffer byteBuffer = imageUint8.getBuffer();
  24627. - byteBuffer.rewind();
  24628. - ColorSpaceType colorSpaceType = image.getColorSpaceType();
  24629. - if (byteBuffer.isDirect()) {
  24630. - return FrameBufferData.create(
  24631. - createFrameBufferFromByteBuffer(
  24632. - byteBuffer,
  24633. - imageUint8.getWidth(),
  24634. - imageUint8.getHeight(),
  24635. - orientation,
  24636. - colorSpaceType.getValue()),
  24637. - // FrameBuffer created with direct ByteBuffer does not require memory freeing.
  24638. - /*byteArrayHandle=*/ 0,
  24639. - /*byteArray=*/ new byte[0]);
  24640. - } else {
  24641. - // If the byte array is copied in jni (during GetByteArrayElements), need to free
  24642. - // the copied array once inference is done.
  24643. - long[] byteArrayHandle = new long[1];
  24644. - byte[] byteArray = getBytesFromByteBuffer(byteBuffer);
  24645. - return FrameBufferData.create(
  24646. - createFrameBufferFromBytes(
  24647. - byteArray,
  24648. - imageUint8.getWidth(),
  24649. - imageUint8.getHeight(),
  24650. - orientation,
  24651. - colorSpaceType.getValue(),
  24652. - byteArrayHandle),
  24653. - byteArrayHandle[0],
  24654. - byteArray);
  24655. + /** Runs inference with {@link TensorImage} and {@link ImageProcessingOptions}. */
  24656. + protected <T> T run(
  24657. + InferenceProvider<T> provider, TensorImage image, ImageProcessingOptions options) {
  24658. + FrameBufferData frameBufferData =
  24659. + createFrameBuffer(image, options.getOrientation().getValue());
  24660. + T results = provider.run(frameBufferData.getFrameBufferHandle(), image.getWidth(),
  24661. + image.getHeight(), options);
  24662. + deleteFrameBuffer(frameBufferData.getFrameBufferHandle(),
  24663. + frameBufferData.getByteArrayHandle(), frameBufferData.getByteArray());
  24664. + return results;
  24665. }
  24666. - }
  24667. - /** Holds the FrameBuffer and the underlying data pointers in C++. */
  24668. - @AutoValue
  24669. - abstract static class FrameBufferData {
  24670. + private static FrameBufferData createFrameBuffer(TensorImage image, int orientation) {
  24671. + ColorSpaceType colorSpaceType = image.getColorSpaceType();
  24672. + switch (colorSpaceType) {
  24673. + case RGB:
  24674. + case NV12:
  24675. + case NV21:
  24676. + case YV12:
  24677. + case YV21:
  24678. + // All these types can be converted to ByteBuffer inside TensorImage. Creating
  24679. + // FrameBuffer base on the image ByteBuffer.
  24680. + return createFrameBufferFromByteBuffer(image, orientation);
  24681. + case YUV_420_888:
  24682. + // YUV_420_888 is a specific type for android.media.Image.
  24683. + return createFrameBufferFromMediaImage(image, orientation);
  24684. + default:
  24685. + throw new IllegalArgumentException(
  24686. + "Color space type, " + colorSpaceType.name() + ", is unsupported.");
  24687. + }
  24688. + }
  24689. /**
  24690. - * Initializes a {@link FrameBufferData} object.
  24691. - *
  24692. - * @param frameBufferHandle the native handle to the FrameBuffer object.
  24693. - * @param byteArrayHandle the native handle to the data array that backs up the FrameBuffer
  24694. - * object. If the FrameBuffer is created on a byte array, this byte array need to be freed
  24695. - * after inference is done. If the FrameBuffer is created on a direct ByteBuffer, no byte
  24696. - * array needs to be freed, and byteArrayHandle will be 0.
  24697. - * @param byteArray the byte array that is used to create the c++ byte array object, which is
  24698. - * needed when releasing byteArrayHandle. If the FrameBuffer is created on a direct
  24699. - * ByteBuffer (no byte array needs to be freed), pass in an empty array for {@code
  24700. - * byteArray}.
  24701. + * Creates FrameBuffer from the {@link android.media.Image} stored in the given {@link
  24702. + * TensorImage}.
  24703. */
  24704. - public static FrameBufferData create(
  24705. - long frameBufferHandle, long byteArrayHandle, byte[] byteArray) {
  24706. - return new AutoValue_BaseVisionTaskApi_FrameBufferData(
  24707. - frameBufferHandle, byteArrayHandle, byteArray);
  24708. + private static FrameBufferData createFrameBufferFromMediaImage(
  24709. + TensorImage image, int orientation) {
  24710. + Image mediaImage = image.getMediaImage();
  24711. +
  24712. + checkArgument(mediaImage.getFormat() == ImageFormat.YUV_420_888,
  24713. + "Only supports loading YUV_420_888 Image.");
  24714. +
  24715. + Plane[] planes = mediaImage.getPlanes();
  24716. + checkArgument(planes.length == 3,
  24717. + String.format("The input image should have 3 planes, but got %d plane(s).",
  24718. + planes.length));
  24719. +
  24720. + // Verify and rewind planes.
  24721. + for (Plane plane : planes) {
  24722. + ByteBuffer buffer = plane.getBuffer();
  24723. + checkNotNull(buffer, "The image buffer is corrupted and the plane is null.");
  24724. + // From the public documentation, plane.getBuffer() should always return a direct
  24725. + // ByteBuffer. See
  24726. + // https://developer.android.com/reference/android/media/Image.Plane#getBuffer()
  24727. + checkArgument(buffer.isDirect(),
  24728. + "The image plane buffer is not a direct ByteBuffer, and is not supported.");
  24729. + buffer.rewind();
  24730. + }
  24731. +
  24732. + return FrameBufferData.create(
  24733. + createFrameBufferFromPlanes(planes[0].getBuffer(), planes[1].getBuffer(),
  24734. + planes[2].getBuffer(), mediaImage.getWidth(), mediaImage.getHeight(),
  24735. + planes[0].getRowStride(),
  24736. + // row_stride and pixel_stride should be identical for U/V planes.
  24737. + planes[1].getRowStride(), planes[1].getPixelStride(), orientation),
  24738. + // FrameBuffer created with direct ByteBuffer does not require memory freeing.
  24739. + /*byteArrayHandle=*/0,
  24740. + /*byteArray=*/new byte[0]);
  24741. + }
  24742. +
  24743. + /** Creates FrameBuffer from the {@link ByteBuffer} stored in the given {@link TensorImage}. */
  24744. + private static FrameBufferData createFrameBufferFromByteBuffer(
  24745. + TensorImage image, int orientation) {
  24746. + // base_vision_api_jni.cc expects an uint8 image. Convert image of other types into uint8.
  24747. + TensorImage imageUint8 = image.getDataType() == DataType.UINT8
  24748. + ? image
  24749. + : TensorImage.createFrom(image, DataType.UINT8);
  24750. +
  24751. + ByteBuffer byteBuffer = imageUint8.getBuffer();
  24752. + byteBuffer.rewind();
  24753. + ColorSpaceType colorSpaceType = image.getColorSpaceType();
  24754. + if (byteBuffer.isDirect()) {
  24755. + return FrameBufferData.create(
  24756. + createFrameBufferFromByteBuffer(byteBuffer, imageUint8.getWidth(),
  24757. + imageUint8.getHeight(), orientation, colorSpaceType.getValue()),
  24758. + // FrameBuffer created with direct ByteBuffer does not require memory freeing.
  24759. + /*byteArrayHandle=*/0,
  24760. + /*byteArray=*/new byte[0]);
  24761. + } else {
  24762. + // If the byte array is copied in jni (during GetByteArrayElements), need to free
  24763. + // the copied array once inference is done.
  24764. + long[] byteArrayHandle = new long[1];
  24765. + byte[] byteArray = getBytesFromByteBuffer(byteBuffer);
  24766. + return FrameBufferData.create(
  24767. + createFrameBufferFromBytes(byteArray, imageUint8.getWidth(),
  24768. + imageUint8.getHeight(), orientation, colorSpaceType.getValue(),
  24769. + byteArrayHandle),
  24770. + byteArrayHandle[0], byteArray);
  24771. + }
  24772. + }
  24773. +
  24774. + /** Holds the FrameBuffer and the underlying data pointers in C++. */
  24775. + @AutoValue
  24776. + abstract static class FrameBufferData {
  24777. + /**
  24778. + * Initializes a {@link FrameBufferData} object.
  24779. + *
  24780. + * @param frameBufferHandle the native handle to the FrameBuffer object.
  24781. + * @param byteArrayHandle the native handle to the data array that backs up the FrameBuffer
  24782. + * object. If the FrameBuffer is created on a byte array, this byte array need to be
  24783. + * freed after inference is done. If the FrameBuffer is created on a direct ByteBuffer, no
  24784. + * byte array needs to be freed, and byteArrayHandle will be 0.
  24785. + * @param byteArray the byte array that is used to create the c++ byte array object, which
  24786. + * is
  24787. + * needed when releasing byteArrayHandle. If the FrameBuffer is created on a direct
  24788. + * ByteBuffer (no byte array needs to be freed), pass in an empty array for {@code
  24789. + * byteArray}.
  24790. + */
  24791. + public static FrameBufferData create(
  24792. + long frameBufferHandle, long byteArrayHandle, byte[] byteArray) {
  24793. + return new AutoValue_BaseVisionTaskApi_FrameBufferData(
  24794. + frameBufferHandle, byteArrayHandle, byteArray);
  24795. + }
  24796. +
  24797. + abstract long getFrameBufferHandle();
  24798. +
  24799. + abstract long getByteArrayHandle();
  24800. +
  24801. + // Package private method for transferring data.
  24802. + @SuppressWarnings("mutable")
  24803. + abstract byte[] getByteArray();
  24804. }
  24805. - abstract long getFrameBufferHandle();
  24806. -
  24807. - abstract long getByteArrayHandle();
  24808. -
  24809. - // Package private method for transferring data.
  24810. - @SuppressWarnings("mutable")
  24811. - abstract byte[] getByteArray();
  24812. - }
  24813. -
  24814. - private static native long createFrameBufferFromByteBuffer(
  24815. - ByteBuffer image, int width, int height, int orientation, int colorSpaceType);
  24816. -
  24817. - private static native long createFrameBufferFromBytes(
  24818. - byte[] image,
  24819. - int width,
  24820. - int height,
  24821. - int orientation,
  24822. - int colorSpaceType,
  24823. - long[] byteArrayHandle);
  24824. -
  24825. - private static native long createFrameBufferFromPlanes(
  24826. - ByteBuffer yBuffer,
  24827. - ByteBuffer uBuffer,
  24828. - ByteBuffer vBuffer,
  24829. - int width,
  24830. - int height,
  24831. - int yRowStride,
  24832. - int uvRowStride,
  24833. - int uvPixelStride,
  24834. - int orientation);
  24835. -
  24836. - private static native void deleteFrameBuffer(
  24837. - long frameBufferHandle, long byteArrayHandle, byte[] byteArray);
  24838. -
  24839. - private static byte[] getBytesFromByteBuffer(ByteBuffer byteBuffer) {
  24840. - // If the ByteBuffer has a back up array, use it directly without copy.
  24841. - if (byteBuffer.hasArray() && byteBuffer.arrayOffset() == 0) {
  24842. - return byteBuffer.array();
  24843. + private static native long createFrameBufferFromByteBuffer(
  24844. + ByteBuffer image, int width, int height, int orientation, int colorSpaceType);
  24845. +
  24846. + private static native long createFrameBufferFromBytes(byte[] image, int width, int height,
  24847. + int orientation, int colorSpaceType, long[] byteArrayHandle);
  24848. +
  24849. + private static native long createFrameBufferFromPlanes(ByteBuffer yBuffer, ByteBuffer uBuffer,
  24850. + ByteBuffer vBuffer, int width, int height, int yRowStride, int uvRowStride,
  24851. + int uvPixelStride, int orientation);
  24852. +
  24853. + private static native void deleteFrameBuffer(
  24854. + long frameBufferHandle, long byteArrayHandle, byte[] byteArray);
  24855. +
  24856. + private static byte[] getBytesFromByteBuffer(ByteBuffer byteBuffer) {
  24857. + // If the ByteBuffer has a back up array, use it directly without copy.
  24858. + if (byteBuffer.hasArray() && byteBuffer.arrayOffset() == 0) {
  24859. + return byteBuffer.array();
  24860. + }
  24861. + // Copy out the data otherwise.
  24862. + byteBuffer.rewind();
  24863. + byte[] bytes = new byte[byteBuffer.limit()];
  24864. + byteBuffer.get(bytes, 0, bytes.length);
  24865. + return bytes;
  24866. }
  24867. - // Copy out the data otherwise.
  24868. - byteBuffer.rewind();
  24869. - byte[] bytes = new byte[byteBuffer.limit()];
  24870. - byteBuffer.get(bytes, 0, bytes.length);
  24871. - return bytes;
  24872. - }
  24873. }
  24874. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java
  24875. index 859e41fc038be..096af521c6b00 100644
  24876. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java
  24877. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java
  24878. @@ -16,27 +16,29 @@ limitations under the License.
  24879. package org.tensorflow.lite.task.vision.detector;
  24880. import android.graphics.RectF;
  24881. +
  24882. import com.google.auto.value.AutoValue;
  24883. +
  24884. +import org.tensorflow.lite.support.label.Category;
  24885. +import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  24886. +
  24887. import java.util.ArrayList;
  24888. import java.util.Collections;
  24889. import java.util.List;
  24890. -import org.tensorflow.lite.support.label.Category;
  24891. -import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  24892. /** Represents one detected object in the results of a {@link ObjectDetector}. */
  24893. @AutoValue
  24894. @UsedByReflection("object_detection_jni.cc")
  24895. public abstract class Detection {
  24896. + @UsedByReflection("object_detection_jni.cc")
  24897. + public static Detection create(RectF boundingBox, List<Category> categories) {
  24898. + return new AutoValue_Detection(new RectF(boundingBox),
  24899. + Collections.unmodifiableList(new ArrayList<Category>(categories)));
  24900. + }
  24901. - @UsedByReflection("object_detection_jni.cc")
  24902. - public static Detection create(RectF boundingBox, List<Category> categories) {
  24903. - return new AutoValue_Detection(
  24904. - new RectF(boundingBox), Collections.unmodifiableList(new ArrayList<Category>(categories)));
  24905. - }
  24906. -
  24907. - public abstract RectF getBoundingBox();
  24908. + public abstract RectF getBoundingBox();
  24909. - // Same reason for not using ImmutableList as stated in
  24910. - // {@link ObjectDetector#ObjectDetectorOptions#labelAllowList}.
  24911. - public abstract List<Category> getCategories();
  24912. + // Same reason for not using ImmutableList as stated in
  24913. + // {@link ObjectDetector#ObjectDetectorOptions#labelAllowList}.
  24914. + public abstract List<Category> getCategories();
  24915. }
  24916. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java
  24917. index 4aff7bfab8ca5..d1fb421fc0bbf 100644
  24918. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java
  24919. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java
  24920. @@ -17,14 +17,9 @@ package org.tensorflow.lite.task.vision.detector;
  24921. import android.content.Context;
  24922. import android.os.ParcelFileDescriptor;
  24923. +
  24924. import com.google.android.odml.image.MlImage;
  24925. -import java.io.File;
  24926. -import java.io.IOException;
  24927. -import java.nio.ByteBuffer;
  24928. -import java.nio.MappedByteBuffer;
  24929. -import java.util.ArrayList;
  24930. -import java.util.Collections;
  24931. -import java.util.List;
  24932. +
  24933. import org.tensorflow.lite.support.image.MlImageAdapter;
  24934. import org.tensorflow.lite.support.image.TensorImage;
  24935. import org.tensorflow.lite.task.core.BaseOptions;
  24936. @@ -35,6 +30,14 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  24937. import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
  24938. import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
  24939. +import java.io.File;
  24940. +import java.io.IOException;
  24941. +import java.nio.ByteBuffer;
  24942. +import java.nio.MappedByteBuffer;
  24943. +import java.util.ArrayList;
  24944. +import java.util.Collections;
  24945. +import java.util.List;
  24946. +
  24947. /**
  24948. * Performs object detection on images.
  24949. *
  24950. @@ -86,469 +89,447 @@ import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
  24951. * Hub.</a>.
  24952. */
  24953. public final class ObjectDetector extends BaseVisionTaskApi {
  24954. + private static final String OBJECT_DETECTOR_NATIVE_LIB = "task_vision_jni";
  24955. + private static final int OPTIONAL_FD_LENGTH = -1;
  24956. + private static final int OPTIONAL_FD_OFFSET = -1;
  24957. +
  24958. + /**
  24959. + * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}.
  24960. + *
  24961. + * @param modelPath path to the detection model with metadata in the assets
  24962. + * @throws IOException if an I/O error occurs when loading the tflite model
  24963. + * @throws IllegalArgumentException if an argument is invalid
  24964. + * @throws IllegalStateException if there is an internal error
  24965. + * @throws RuntimeException if there is an otherwise unspecified error
  24966. + */
  24967. + public static ObjectDetector createFromFile(Context context, String modelPath)
  24968. + throws IOException {
  24969. + return createFromFileAndOptions(
  24970. + context, modelPath, ObjectDetectorOptions.builder().build());
  24971. + }
  24972. - private static final String OBJECT_DETECTOR_NATIVE_LIB = "task_vision_jni";
  24973. - private static final int OPTIONAL_FD_LENGTH = -1;
  24974. - private static final int OPTIONAL_FD_OFFSET = -1;
  24975. -
  24976. - /**
  24977. - * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}.
  24978. - *
  24979. - * @param modelPath path to the detection model with metadata in the assets
  24980. - * @throws IOException if an I/O error occurs when loading the tflite model
  24981. - * @throws IllegalArgumentException if an argument is invalid
  24982. - * @throws IllegalStateException if there is an internal error
  24983. - * @throws RuntimeException if there is an otherwise unspecified error
  24984. - */
  24985. - public static ObjectDetector createFromFile(Context context, String modelPath)
  24986. - throws IOException {
  24987. - return createFromFileAndOptions(context, modelPath, ObjectDetectorOptions.builder().build());
  24988. - }
  24989. -
  24990. - /**
  24991. - * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}.
  24992. - *
  24993. - * @param modelFile the detection model {@link File} instance
  24994. - * @throws IOException if an I/O error occurs when loading the tflite model
  24995. - * @throws IllegalArgumentException if an argument is invalid
  24996. - * @throws IllegalStateException if there is an internal error
  24997. - * @throws RuntimeException if there is an otherwise unspecified error
  24998. - */
  24999. - public static ObjectDetector createFromFile(File modelFile) throws IOException {
  25000. - return createFromFileAndOptions(modelFile, ObjectDetectorOptions.builder().build());
  25001. - }
  25002. -
  25003. - /**
  25004. - * Creates an {@link ObjectDetector} instance with a model buffer and the default {@link
  25005. - * ObjectDetectorOptions}.
  25006. - *
  25007. - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
  25008. - * model
  25009. - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  25010. - * {@link MappedByteBuffer} * @throws IllegalStateException if there is an internal error
  25011. - * @throws RuntimeException if there is an otherwise unspecified error
  25012. - */
  25013. - public static ObjectDetector createFromBuffer(final ByteBuffer modelBuffer) {
  25014. - return createFromBufferAndOptions(modelBuffer, ObjectDetectorOptions.builder().build());
  25015. - }
  25016. -
  25017. - /**
  25018. - * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}.
  25019. - *
  25020. - * @param modelPath path to the detection model with metadata in the assets
  25021. - * @throws IOException if an I/O error occurs when loading the tflite model
  25022. - * @throws IllegalArgumentException if an argument is invalid
  25023. - * @throws IllegalStateException if there is an internal error
  25024. - * @throws RuntimeException if there is an otherwise unspecified error
  25025. - */
  25026. - public static ObjectDetector createFromFileAndOptions(
  25027. - Context context, String modelPath, ObjectDetectorOptions options) throws IOException {
  25028. - return new ObjectDetector(
  25029. - TaskJniUtils.createHandleFromFdAndOptions(
  25030. - context,
  25031. - new FdAndOptionsHandleProvider<ObjectDetectorOptions>() {
  25032. - @Override
  25033. - public long createHandle(
  25034. - int fileDescriptor,
  25035. - long fileDescriptorLength,
  25036. - long fileDescriptorOffset,
  25037. - ObjectDetectorOptions options) {
  25038. - return initJniWithModelFdAndOptions(
  25039. - fileDescriptor,
  25040. - fileDescriptorLength,
  25041. - fileDescriptorOffset,
  25042. - options,
  25043. - TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
  25044. - options.getBaseOptions(), options.getNumThreads()));
  25045. - }
  25046. - },
  25047. - OBJECT_DETECTOR_NATIVE_LIB,
  25048. - modelPath,
  25049. - options));
  25050. - }
  25051. -
  25052. - /**
  25053. - * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}.
  25054. - *
  25055. - * @param modelFile the detection model {@link File} instance
  25056. - * @throws IOException if an I/O error occurs when loading the tflite model
  25057. - * @throws IllegalArgumentException if an argument is invalid
  25058. - * @throws IllegalStateException if there is an internal error
  25059. - * @throws RuntimeException if there is an otherwise unspecified error
  25060. - */
  25061. - public static ObjectDetector createFromFileAndOptions(
  25062. - File modelFile, final ObjectDetectorOptions options) throws IOException {
  25063. - try (ParcelFileDescriptor descriptor =
  25064. - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  25065. - return new ObjectDetector(
  25066. - TaskJniUtils.createHandleFromLibrary(
  25067. - new TaskJniUtils.EmptyHandleProvider() {
  25068. - @Override
  25069. - public long createHandle() {
  25070. - return initJniWithModelFdAndOptions(
  25071. - descriptor.getFd(),
  25072. - /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
  25073. - /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
  25074. - options,
  25075. - TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
  25076. - options.getBaseOptions(), options.getNumThreads()));
  25077. - }
  25078. - },
  25079. - OBJECT_DETECTOR_NATIVE_LIB));
  25080. + /**
  25081. + * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}.
  25082. + *
  25083. + * @param modelFile the detection model {@link File} instance
  25084. + * @throws IOException if an I/O error occurs when loading the tflite model
  25085. + * @throws IllegalArgumentException if an argument is invalid
  25086. + * @throws IllegalStateException if there is an internal error
  25087. + * @throws RuntimeException if there is an otherwise unspecified error
  25088. + */
  25089. + public static ObjectDetector createFromFile(File modelFile) throws IOException {
  25090. + return createFromFileAndOptions(modelFile, ObjectDetectorOptions.builder().build());
  25091. }
  25092. - }
  25093. -
  25094. - /**
  25095. - * Creates an {@link ObjectDetector} instance with a model buffer and {@link
  25096. - * ObjectDetectorOptions}.
  25097. - *
  25098. - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
  25099. - * model
  25100. - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  25101. - * {@link MappedByteBuffer}
  25102. - * @throws IllegalStateException if there is an internal error
  25103. - * @throws RuntimeException if there is an otherwise unspecified error
  25104. - */
  25105. - public static ObjectDetector createFromBufferAndOptions(
  25106. - final ByteBuffer modelBuffer, final ObjectDetectorOptions options) {
  25107. - if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
  25108. - throw new IllegalArgumentException(
  25109. - "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
  25110. +
  25111. + /**
  25112. + * Creates an {@link ObjectDetector} instance with a model buffer and the default {@link
  25113. + * ObjectDetectorOptions}.
  25114. + *
  25115. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
  25116. + * model
  25117. + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  25118. + * {@link MappedByteBuffer} * @throws IllegalStateException if there is an internal error
  25119. + * @throws RuntimeException if there is an otherwise unspecified error
  25120. + */
  25121. + public static ObjectDetector createFromBuffer(final ByteBuffer modelBuffer) {
  25122. + return createFromBufferAndOptions(modelBuffer, ObjectDetectorOptions.builder().build());
  25123. }
  25124. - return new ObjectDetector(
  25125. - TaskJniUtils.createHandleFromLibrary(
  25126. - new EmptyHandleProvider() {
  25127. - @Override
  25128. - public long createHandle() {
  25129. - return initJniWithByteBuffer(
  25130. - modelBuffer,
  25131. - options,
  25132. - TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
  25133. - options.getBaseOptions(), options.getNumThreads()));
  25134. - }
  25135. - },
  25136. - OBJECT_DETECTOR_NATIVE_LIB));
  25137. - }
  25138. -
  25139. - /**
  25140. - * Constructor to initialize the JNI with a pointer from C++.
  25141. - *
  25142. - * @param nativeHandle a pointer referencing memory allocated in C++
  25143. - */
  25144. - private ObjectDetector(long nativeHandle) {
  25145. - super(nativeHandle);
  25146. - }
  25147. -
  25148. - /** Options for setting up an ObjectDetector. */
  25149. - @UsedByReflection("object_detector_jni.cc")
  25150. - public static class ObjectDetectorOptions {
  25151. - // Not using AutoValue for this class because scoreThreshold cannot have default value
  25152. - // (otherwise, the default value would override the one in the model metadata) and `Optional` is
  25153. - // not an option here, because
  25154. - // 1. java.util.Optional require Java 8 while we need to support Java 7.
  25155. - // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the
  25156. - // comments for labelAllowList.
  25157. - private final BaseOptions baseOptions;
  25158. - private final String displayNamesLocale;
  25159. - private final int maxResults;
  25160. - private final float scoreThreshold;
  25161. - private final boolean isScoreThresholdSet;
  25162. - // As an open source project, we've been trying avoiding depending on common java libraries,
  25163. - // such as Guava, because it may introduce conflicts with clients who also happen to use those
  25164. - // libraries. Therefore, instead of using ImmutableList here, we convert the List into
  25165. - // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
  25166. - // vulnerable.
  25167. - private final List<String> labelAllowList;
  25168. - private final List<String> labelDenyList;
  25169. - private final int numThreads;
  25170. -
  25171. - public static Builder builder() {
  25172. - return new Builder();
  25173. +
  25174. + /**
  25175. + * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}.
  25176. + *
  25177. + * @param modelPath path to the detection model with metadata in the assets
  25178. + * @throws IOException if an I/O error occurs when loading the tflite model
  25179. + * @throws IllegalArgumentException if an argument is invalid
  25180. + * @throws IllegalStateException if there is an internal error
  25181. + * @throws RuntimeException if there is an otherwise unspecified error
  25182. + */
  25183. + public static ObjectDetector createFromFileAndOptions(
  25184. + Context context, String modelPath, ObjectDetectorOptions options) throws IOException {
  25185. + return new ObjectDetector(TaskJniUtils.createHandleFromFdAndOptions(
  25186. + context, new FdAndOptionsHandleProvider<ObjectDetectorOptions>() {
  25187. + @Override
  25188. + public long createHandle(int fileDescriptor, long fileDescriptorLength,
  25189. + long fileDescriptorOffset, ObjectDetectorOptions options) {
  25190. + return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength,
  25191. + fileDescriptorOffset, options,
  25192. + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
  25193. + options.getBaseOptions(), options.getNumThreads()));
  25194. + }
  25195. + }, OBJECT_DETECTOR_NATIVE_LIB, modelPath, options));
  25196. }
  25197. - /** A builder that helps to configure an instance of ObjectDetectorOptions. */
  25198. - public static class Builder {
  25199. - private BaseOptions baseOptions = BaseOptions.builder().build();
  25200. - private String displayNamesLocale = "en";
  25201. - private int maxResults = -1;
  25202. - private float scoreThreshold;
  25203. - private boolean isScoreThresholdSet = false;
  25204. - private List<String> labelAllowList = new ArrayList<>();
  25205. - private List<String> labelDenyList = new ArrayList<>();
  25206. - private int numThreads = -1;
  25207. -
  25208. - private Builder() {}
  25209. -
  25210. - /** Sets the general options to configure Task APIs, such as accelerators. */
  25211. - public Builder setBaseOptions(BaseOptions baseOptions) {
  25212. - this.baseOptions = baseOptions;
  25213. - return this;
  25214. - }
  25215. -
  25216. - /**
  25217. - * Sets the locale to use for display names specified through the TFLite Model Metadata, if
  25218. - * any.
  25219. - *
  25220. - * <p>Defaults to English({@code "en"}). See the <a
  25221. - * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
  25222. - * Metadata schema file.</a> for the accepted pattern of locale.
  25223. - */
  25224. - public Builder setDisplayNamesLocale(String displayNamesLocale) {
  25225. - this.displayNamesLocale = displayNamesLocale;
  25226. - return this;
  25227. - }
  25228. -
  25229. - /**
  25230. - * Sets the maximum number of top-scored detection results to return.
  25231. - *
  25232. - * <p>If < 0, all available results will be returned. If 0, an invalid argument error is
  25233. - * returned. Note that models may intrinsically be limited to returning a maximum number of
  25234. - * results N: if the provided value here is above N, only N results will be returned. Defaults
  25235. - * to -1.
  25236. - *
  25237. - * @throws IllegalArgumentException if maxResults is 0.
  25238. - */
  25239. - public Builder setMaxResults(int maxResults) {
  25240. - if (maxResults == 0) {
  25241. - throw new IllegalArgumentException("maxResults cannot be 0.");
  25242. + /**
  25243. + * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}.
  25244. + *
  25245. + * @param modelFile the detection model {@link File} instance
  25246. + * @throws IOException if an I/O error occurs when loading the tflite model
  25247. + * @throws IllegalArgumentException if an argument is invalid
  25248. + * @throws IllegalStateException if there is an internal error
  25249. + * @throws RuntimeException if there is an otherwise unspecified error
  25250. + */
  25251. + public static ObjectDetector createFromFileAndOptions(
  25252. + File modelFile, final ObjectDetectorOptions options) throws IOException {
  25253. + try (ParcelFileDescriptor descriptor =
  25254. + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  25255. + return new ObjectDetector(
  25256. + TaskJniUtils.createHandleFromLibrary(new TaskJniUtils.EmptyHandleProvider() {
  25257. + @Override
  25258. + public long createHandle() {
  25259. + return initJniWithModelFdAndOptions(descriptor.getFd(),
  25260. + /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH,
  25261. + /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options,
  25262. + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
  25263. + options.getBaseOptions(), options.getNumThreads()));
  25264. + }
  25265. + }, OBJECT_DETECTOR_NATIVE_LIB));
  25266. }
  25267. - this.maxResults = maxResults;
  25268. - return this;
  25269. - }
  25270. -
  25271. - /**
  25272. - * Sets the score threshold that overrides the one provided in the model metadata (if any).
  25273. - * Results below this value are rejected.
  25274. - */
  25275. - public Builder setScoreThreshold(float scoreThreshold) {
  25276. - this.scoreThreshold = scoreThreshold;
  25277. - this.isScoreThresholdSet = true;
  25278. - return this;
  25279. - }
  25280. -
  25281. - /**
  25282. - * Sets the optional allow list of labels.
  25283. - *
  25284. - * <p>If non-empty, detection results whose label is not in this set will be filtered out.
  25285. - * Duplicate or unknown labels are ignored. Mutually exclusive with {@code labelDenyList}. It
  25286. - * will cause {@link IllegalStateException} when calling {@link #createFromFileAndOptions}, if
  25287. - * both {@code labelDenyList} and {@code labelAllowList} are set.
  25288. - */
  25289. - public Builder setLabelAllowList(List<String> labelAllowList) {
  25290. - this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
  25291. - return this;
  25292. - }
  25293. -
  25294. - /**
  25295. - * Sets the optional deny list of labels.
  25296. - *
  25297. - * <p>If non-empty, detection results whose label is in this set will be filtered out.
  25298. - * Duplicate or unknown labels are ignored. Mutually exclusive with {@code labelAllowList}. It
  25299. - * will cause {@link IllegalStateException} when calling {@link #createFromFileAndOptions}, if
  25300. - * both {@code labelDenyList} and {@code labelAllowList} are set.
  25301. - */
  25302. - public Builder setLabelDenyList(List<String> labelDenyList) {
  25303. - this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
  25304. - return this;
  25305. - }
  25306. -
  25307. - /**
  25308. - * Sets the number of threads to be used for TFLite ops that support multi-threading when
  25309. - * running inference with CPU. Defaults to -1.
  25310. - *
  25311. - * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
  25312. - * effect to let TFLite runtime set the value.
  25313. - *
  25314. - * @deprecated use {@link BaseOptions} to configure number of threads instead. This method
  25315. - * will override the number of threads configured from {@link BaseOptions}.
  25316. - */
  25317. - @Deprecated
  25318. - public Builder setNumThreads(int numThreads) {
  25319. - this.numThreads = numThreads;
  25320. - return this;
  25321. - }
  25322. -
  25323. - public ObjectDetectorOptions build() {
  25324. - return new ObjectDetectorOptions(this);
  25325. - }
  25326. }
  25327. - @UsedByReflection("object_detector_jni.cc")
  25328. - public String getDisplayNamesLocale() {
  25329. - return displayNamesLocale;
  25330. + /**
  25331. + * Creates an {@link ObjectDetector} instance with a model buffer and {@link
  25332. + * ObjectDetectorOptions}.
  25333. + *
  25334. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
  25335. + * model
  25336. + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  25337. + * {@link MappedByteBuffer}
  25338. + * @throws IllegalStateException if there is an internal error
  25339. + * @throws RuntimeException if there is an otherwise unspecified error
  25340. + */
  25341. + public static ObjectDetector createFromBufferAndOptions(
  25342. + final ByteBuffer modelBuffer, final ObjectDetectorOptions options) {
  25343. + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
  25344. + throw new IllegalArgumentException(
  25345. + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
  25346. + }
  25347. + return new ObjectDetector(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
  25348. + @Override
  25349. + public long createHandle() {
  25350. + return initJniWithByteBuffer(modelBuffer, options,
  25351. + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
  25352. + options.getBaseOptions(), options.getNumThreads()));
  25353. + }
  25354. + }, OBJECT_DETECTOR_NATIVE_LIB));
  25355. }
  25356. - @UsedByReflection("object_detector_jni.cc")
  25357. - public int getMaxResults() {
  25358. - return maxResults;
  25359. + /**
  25360. + * Constructor to initialize the JNI with a pointer from C++.
  25361. + *
  25362. + * @param nativeHandle a pointer referencing memory allocated in C++
  25363. + */
  25364. + private ObjectDetector(long nativeHandle) {
  25365. + super(nativeHandle);
  25366. }
  25367. + /** Options for setting up an ObjectDetector. */
  25368. @UsedByReflection("object_detector_jni.cc")
  25369. - public float getScoreThreshold() {
  25370. - return scoreThreshold;
  25371. + public static class ObjectDetectorOptions {
  25372. + // Not using AutoValue for this class because scoreThreshold cannot have default value
  25373. + // (otherwise, the default value would override the one in the model metadata) and
  25374. + // `Optional` is not an option here, because
  25375. + // 1. java.util.Optional require Java 8 while we need to support Java 7.
  25376. + // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See
  25377. + // the comments for labelAllowList.
  25378. + private final BaseOptions baseOptions;
  25379. + private final String displayNamesLocale;
  25380. + private final int maxResults;
  25381. + private final float scoreThreshold;
  25382. + private final boolean isScoreThresholdSet;
  25383. + // As an open source project, we've been trying avoiding depending on common java libraries,
  25384. + // such as Guava, because it may introduce conflicts with clients who also happen to use
  25385. + // those libraries. Therefore, instead of using ImmutableList here, we convert the List into
  25386. + // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
  25387. + // vulnerable.
  25388. + private final List<String> labelAllowList;
  25389. + private final List<String> labelDenyList;
  25390. + private final int numThreads;
  25391. +
  25392. + public static Builder builder() {
  25393. + return new Builder();
  25394. + }
  25395. +
  25396. + /** A builder that helps to configure an instance of ObjectDetectorOptions. */
  25397. + public static class Builder {
  25398. + private BaseOptions baseOptions = BaseOptions.builder().build();
  25399. + private String displayNamesLocale = "en";
  25400. + private int maxResults = -1;
  25401. + private float scoreThreshold;
  25402. + private boolean isScoreThresholdSet = false;
  25403. + private List<String> labelAllowList = new ArrayList<>();
  25404. + private List<String> labelDenyList = new ArrayList<>();
  25405. + private int numThreads = -1;
  25406. +
  25407. + private Builder() {}
  25408. +
  25409. + /** Sets the general options to configure Task APIs, such as accelerators. */
  25410. + public Builder setBaseOptions(BaseOptions baseOptions) {
  25411. + this.baseOptions = baseOptions;
  25412. + return this;
  25413. + }
  25414. +
  25415. + /**
  25416. + * Sets the locale to use for display names specified through the TFLite Model Metadata,
  25417. + * if any.
  25418. + *
  25419. + * <p>Defaults to English({@code "en"}). See the <a
  25420. + * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
  25421. + * Metadata schema file.</a> for the accepted pattern of locale.
  25422. + */
  25423. + public Builder setDisplayNamesLocale(String displayNamesLocale) {
  25424. + this.displayNamesLocale = displayNamesLocale;
  25425. + return this;
  25426. + }
  25427. +
  25428. + /**
  25429. + * Sets the maximum number of top-scored detection results to return.
  25430. + *
  25431. + * <p>If < 0, all available results will be returned. If 0, an invalid argument error is
  25432. + * returned. Note that models may intrinsically be limited to returning a maximum number
  25433. + * of results N: if the provided value here is above N, only N results will be returned.
  25434. + * Defaults to -1.
  25435. + *
  25436. + * @throws IllegalArgumentException if maxResults is 0.
  25437. + */
  25438. + public Builder setMaxResults(int maxResults) {
  25439. + if (maxResults == 0) {
  25440. + throw new IllegalArgumentException("maxResults cannot be 0.");
  25441. + }
  25442. + this.maxResults = maxResults;
  25443. + return this;
  25444. + }
  25445. +
  25446. + /**
  25447. + * Sets the score threshold that overrides the one provided in the model metadata (if
  25448. + * any). Results below this value are rejected.
  25449. + */
  25450. + public Builder setScoreThreshold(float scoreThreshold) {
  25451. + this.scoreThreshold = scoreThreshold;
  25452. + this.isScoreThresholdSet = true;
  25453. + return this;
  25454. + }
  25455. +
  25456. + /**
  25457. + * Sets the optional allow list of labels.
  25458. + *
  25459. + * <p>If non-empty, detection results whose label is not in this set will be filtered
  25460. + * out. Duplicate or unknown labels are ignored. Mutually exclusive with {@code
  25461. + * labelDenyList}. It will cause {@link IllegalStateException} when calling {@link
  25462. + * #createFromFileAndOptions}, if both {@code labelDenyList} and {@code labelAllowList}
  25463. + * are set.
  25464. + */
  25465. + public Builder setLabelAllowList(List<String> labelAllowList) {
  25466. + this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
  25467. + return this;
  25468. + }
  25469. +
  25470. + /**
  25471. + * Sets the optional deny list of labels.
  25472. + *
  25473. + * <p>If non-empty, detection results whose label is in this set will be filtered out.
  25474. + * Duplicate or unknown labels are ignored. Mutually exclusive with {@code
  25475. + * labelAllowList}. It will cause {@link IllegalStateException} when calling {@link
  25476. + * #createFromFileAndOptions}, if both {@code labelDenyList} and {@code labelAllowList}
  25477. + * are set.
  25478. + */
  25479. + public Builder setLabelDenyList(List<String> labelDenyList) {
  25480. + this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
  25481. + return this;
  25482. + }
  25483. +
  25484. + /**
  25485. + * Sets the number of threads to be used for TFLite ops that support multi-threading
  25486. + * when running inference with CPU. Defaults to -1.
  25487. + *
  25488. + * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has
  25489. + * the effect to let TFLite runtime set the value.
  25490. + *
  25491. + * @deprecated use {@link BaseOptions} to configure number of threads instead. This
  25492. + * method
  25493. + * will override the number of threads configured from {@link BaseOptions}.
  25494. + */
  25495. + @Deprecated
  25496. + public Builder setNumThreads(int numThreads) {
  25497. + this.numThreads = numThreads;
  25498. + return this;
  25499. + }
  25500. +
  25501. + public ObjectDetectorOptions build() {
  25502. + return new ObjectDetectorOptions(this);
  25503. + }
  25504. + }
  25505. +
  25506. + @UsedByReflection("object_detector_jni.cc")
  25507. + public String getDisplayNamesLocale() {
  25508. + return displayNamesLocale;
  25509. + }
  25510. +
  25511. + @UsedByReflection("object_detector_jni.cc")
  25512. + public int getMaxResults() {
  25513. + return maxResults;
  25514. + }
  25515. +
  25516. + @UsedByReflection("object_detector_jni.cc")
  25517. + public float getScoreThreshold() {
  25518. + return scoreThreshold;
  25519. + }
  25520. +
  25521. + @UsedByReflection("object_detector_jni.cc")
  25522. + public boolean getIsScoreThresholdSet() {
  25523. + return isScoreThresholdSet;
  25524. + }
  25525. +
  25526. + @UsedByReflection("object_detector_jni.cc")
  25527. + public List<String> getLabelAllowList() {
  25528. + return new ArrayList<>(labelAllowList);
  25529. + }
  25530. +
  25531. + @UsedByReflection("object_detector_jni.cc")
  25532. + public List<String> getLabelDenyList() {
  25533. + return new ArrayList<>(labelDenyList);
  25534. + }
  25535. +
  25536. + @UsedByReflection("object_detector_jni.cc")
  25537. + public int getNumThreads() {
  25538. + return numThreads;
  25539. + }
  25540. +
  25541. + public BaseOptions getBaseOptions() {
  25542. + return baseOptions;
  25543. + }
  25544. +
  25545. + private ObjectDetectorOptions(Builder builder) {
  25546. + displayNamesLocale = builder.displayNamesLocale;
  25547. + maxResults = builder.maxResults;
  25548. + scoreThreshold = builder.scoreThreshold;
  25549. + isScoreThresholdSet = builder.isScoreThresholdSet;
  25550. + labelAllowList = builder.labelAllowList;
  25551. + labelDenyList = builder.labelDenyList;
  25552. + numThreads = builder.numThreads;
  25553. + baseOptions = builder.baseOptions;
  25554. + }
  25555. }
  25556. - @UsedByReflection("object_detector_jni.cc")
  25557. - public boolean getIsScoreThresholdSet() {
  25558. - return isScoreThresholdSet;
  25559. + /**
  25560. + * Performs actual detection on the provided image.
  25561. + *
  25562. + * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types:
  25563. + *
  25564. + * <ul>
  25565. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
  25566. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
  25567. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
  25568. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
  25569. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
  25570. + * </ul>
  25571. + *
  25572. + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
  25573. + * @throws IllegalStateException if there is an internal error
  25574. + * @throws RuntimeException if there is an otherwise unspecified error
  25575. + * @throws IllegalArgumentException if the color space type of image is unsupported
  25576. + */
  25577. + public List<Detection> detect(TensorImage image) {
  25578. + return detect(image, ImageProcessingOptions.builder().build());
  25579. }
  25580. - @UsedByReflection("object_detector_jni.cc")
  25581. - public List<String> getLabelAllowList() {
  25582. - return new ArrayList<>(labelAllowList);
  25583. + /**
  25584. + * Performs actual detection on the provided image.
  25585. + *
  25586. + * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types:
  25587. + *
  25588. + * <ul>
  25589. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
  25590. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
  25591. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
  25592. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
  25593. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
  25594. + * </ul>
  25595. + *
  25596. + * <p>{@link ObjectDetector} supports the following options:
  25597. + *
  25598. + * <ul>
  25599. + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
  25600. + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
  25601. + * </ul>
  25602. + *
  25603. + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
  25604. + * @param options the options to configure how to preprocess the image
  25605. + * @throws IllegalStateException if there is an internal error
  25606. + * @throws RuntimeException if there is an otherwise unspecified error
  25607. + * @throws IllegalArgumentException if the color space type of image is unsupported
  25608. + */
  25609. + public List<Detection> detect(TensorImage image, ImageProcessingOptions options) {
  25610. + return run(new InferenceProvider<List<Detection>>() {
  25611. + @Override
  25612. + public List<Detection> run(
  25613. + long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
  25614. + return detect(frameBufferHandle, options);
  25615. + }
  25616. + }, image, options);
  25617. }
  25618. - @UsedByReflection("object_detector_jni.cc")
  25619. - public List<String> getLabelDenyList() {
  25620. - return new ArrayList<>(labelDenyList);
  25621. + /**
  25622. + * Performs actual detection on the provided {@code MlImage}.
  25623. + *
  25624. + * @param image an {@code MlImage} object that represents an image
  25625. + * @throws IllegalStateException if there is an internal error
  25626. + * @throws RuntimeException if there is an otherwise unspecified error
  25627. + * @throws IllegalArgumentException if the storage type or format of the image is unsupported
  25628. + */
  25629. + public List<Detection> detect(MlImage image) {
  25630. + return detect(image, ImageProcessingOptions.builder().build());
  25631. }
  25632. - @UsedByReflection("object_detector_jni.cc")
  25633. - public int getNumThreads() {
  25634. - return numThreads;
  25635. + /**
  25636. + * Performs actual detection on the provided {@code MlImage} with {@link
  25637. + * ImageProcessingOptions}.
  25638. + *
  25639. + * <p>{@link ObjectDetector} supports the following options:
  25640. + *
  25641. + * <ul>
  25642. + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
  25643. + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
  25644. + * MlImage#getRotation()} is not effective.
  25645. + * </ul>
  25646. + *
  25647. + * @param image an {@code MlImage} object that represents an image
  25648. + * @param options the options to configure how to preprocess the image
  25649. + * @throws IllegalStateException if there is an internal error
  25650. + * @throws RuntimeException if there is an otherwise unspecified error
  25651. + * @throws IllegalArgumentException if the storage type or format of the image is unsupported
  25652. + */
  25653. + public List<Detection> detect(MlImage image, ImageProcessingOptions options) {
  25654. + image.getInternal().acquire();
  25655. + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
  25656. + List<Detection> result = detect(tensorImage, options);
  25657. + image.close();
  25658. + return result;
  25659. }
  25660. - public BaseOptions getBaseOptions() {
  25661. - return baseOptions;
  25662. + private List<Detection> detect(long frameBufferHandle, ImageProcessingOptions options) {
  25663. + checkNotClosed();
  25664. +
  25665. + return detectNative(getNativeHandle(), frameBufferHandle);
  25666. }
  25667. - private ObjectDetectorOptions(Builder builder) {
  25668. - displayNamesLocale = builder.displayNamesLocale;
  25669. - maxResults = builder.maxResults;
  25670. - scoreThreshold = builder.scoreThreshold;
  25671. - isScoreThresholdSet = builder.isScoreThresholdSet;
  25672. - labelAllowList = builder.labelAllowList;
  25673. - labelDenyList = builder.labelDenyList;
  25674. - numThreads = builder.numThreads;
  25675. - baseOptions = builder.baseOptions;
  25676. + private static native long initJniWithModelFdAndOptions(int fileDescriptor,
  25677. + long fileDescriptorLength, long fileDescriptorOffset, ObjectDetectorOptions options,
  25678. + long baseOptionsHandle);
  25679. +
  25680. + private static native long initJniWithByteBuffer(
  25681. + ByteBuffer modelBuffer, ObjectDetectorOptions options, long baseOptionsHandle);
  25682. +
  25683. + private static native List<Detection> detectNative(long nativeHandle, long frameBufferHandle);
  25684. +
  25685. + @Override
  25686. + protected void deinit(long nativeHandle) {
  25687. + deinitJni(nativeHandle);
  25688. }
  25689. - }
  25690. -
  25691. - /**
  25692. - * Performs actual detection on the provided image.
  25693. - *
  25694. - * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types:
  25695. - *
  25696. - * <ul>
  25697. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
  25698. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
  25699. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
  25700. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
  25701. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
  25702. - * </ul>
  25703. - *
  25704. - * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
  25705. - * @throws IllegalStateException if there is an internal error
  25706. - * @throws RuntimeException if there is an otherwise unspecified error
  25707. - * @throws IllegalArgumentException if the color space type of image is unsupported
  25708. - */
  25709. - public List<Detection> detect(TensorImage image) {
  25710. - return detect(image, ImageProcessingOptions.builder().build());
  25711. - }
  25712. -
  25713. - /**
  25714. - * Performs actual detection on the provided image.
  25715. - *
  25716. - * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types:
  25717. - *
  25718. - * <ul>
  25719. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
  25720. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
  25721. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
  25722. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
  25723. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
  25724. - * </ul>
  25725. - *
  25726. - * <p>{@link ObjectDetector} supports the following options:
  25727. - *
  25728. - * <ul>
  25729. - * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
  25730. - * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
  25731. - * </ul>
  25732. - *
  25733. - * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
  25734. - * @param options the options to configure how to preprocess the image
  25735. - * @throws IllegalStateException if there is an internal error
  25736. - * @throws RuntimeException if there is an otherwise unspecified error
  25737. - * @throws IllegalArgumentException if the color space type of image is unsupported
  25738. - */
  25739. - public List<Detection> detect(TensorImage image, ImageProcessingOptions options) {
  25740. - return run(
  25741. - new InferenceProvider<List<Detection>>() {
  25742. - @Override
  25743. - public List<Detection> run(
  25744. - long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
  25745. - return detect(frameBufferHandle, options);
  25746. - }
  25747. - },
  25748. - image,
  25749. - options);
  25750. - }
  25751. -
  25752. - /**
  25753. - * Performs actual detection on the provided {@code MlImage}.
  25754. - *
  25755. - * @param image an {@code MlImage} object that represents an image
  25756. - * @throws IllegalStateException if there is an internal error
  25757. - * @throws RuntimeException if there is an otherwise unspecified error
  25758. - * @throws IllegalArgumentException if the storage type or format of the image is unsupported
  25759. - */
  25760. - public List<Detection> detect(MlImage image) {
  25761. - return detect(image, ImageProcessingOptions.builder().build());
  25762. - }
  25763. -
  25764. - /**
  25765. - * Performs actual detection on the provided {@code MlImage} with {@link ImageProcessingOptions}.
  25766. - *
  25767. - * <p>{@link ObjectDetector} supports the following options:
  25768. - *
  25769. - * <ul>
  25770. - * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
  25771. - * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
  25772. - * MlImage#getRotation()} is not effective.
  25773. - * </ul>
  25774. - *
  25775. - * @param image an {@code MlImage} object that represents an image
  25776. - * @param options the options to configure how to preprocess the image
  25777. - * @throws IllegalStateException if there is an internal error
  25778. - * @throws RuntimeException if there is an otherwise unspecified error
  25779. - * @throws IllegalArgumentException if the storage type or format of the image is unsupported
  25780. - */
  25781. - public List<Detection> detect(MlImage image, ImageProcessingOptions options) {
  25782. - image.getInternal().acquire();
  25783. - TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
  25784. - List<Detection> result = detect(tensorImage, options);
  25785. - image.close();
  25786. - return result;
  25787. - }
  25788. -
  25789. - private List<Detection> detect(long frameBufferHandle, ImageProcessingOptions options) {
  25790. - checkNotClosed();
  25791. -
  25792. - return detectNative(getNativeHandle(), frameBufferHandle);
  25793. - }
  25794. -
  25795. - private static native long initJniWithModelFdAndOptions(
  25796. - int fileDescriptor,
  25797. - long fileDescriptorLength,
  25798. - long fileDescriptorOffset,
  25799. - ObjectDetectorOptions options,
  25800. - long baseOptionsHandle);
  25801. -
  25802. - private static native long initJniWithByteBuffer(
  25803. - ByteBuffer modelBuffer, ObjectDetectorOptions options, long baseOptionsHandle);
  25804. -
  25805. - private static native List<Detection> detectNative(long nativeHandle, long frameBufferHandle);
  25806. -
  25807. - @Override
  25808. - protected void deinit(long nativeHandle) {
  25809. - deinitJni(nativeHandle);
  25810. - }
  25811. -
  25812. - /**
  25813. - * Native implementation to release memory pointed by the pointer.
  25814. - *
  25815. - * @param nativeHandle pointer to memory allocated
  25816. - */
  25817. - private native void deinitJni(long nativeHandle);
  25818. +
  25819. + /**
  25820. + * Native implementation to release memory pointed by the pointer.
  25821. + *
  25822. + * @param nativeHandle pointer to memory allocated
  25823. + */
  25824. + private native void deinitJni(long nativeHandle);
  25825. }
  25826. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/ImageSearcher.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/ImageSearcher.java
  25827. index 7a02ad8a037a2..d3d1e6a4f4878 100644
  25828. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/ImageSearcher.java
  25829. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/ImageSearcher.java
  25830. @@ -19,13 +19,10 @@ import android.content.Context;
  25831. import android.content.res.AssetFileDescriptor;
  25832. import android.graphics.Rect;
  25833. import android.os.ParcelFileDescriptor;
  25834. +
  25835. import com.google.android.odml.image.MlImage;
  25836. import com.google.auto.value.AutoValue;
  25837. -import java.io.File;
  25838. -import java.io.IOException;
  25839. -import java.nio.ByteBuffer;
  25840. -import java.nio.MappedByteBuffer;
  25841. -import java.util.List;
  25842. +
  25843. import org.tensorflow.lite.support.image.MlImageAdapter;
  25844. import org.tensorflow.lite.support.image.TensorImage;
  25845. import org.tensorflow.lite.task.core.BaseOptions;
  25846. @@ -37,6 +34,12 @@ import org.tensorflow.lite.task.processor.SearcherOptions;
  25847. import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
  25848. import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
  25849. +import java.io.File;
  25850. +import java.io.IOException;
  25851. +import java.nio.ByteBuffer;
  25852. +import java.nio.MappedByteBuffer;
  25853. +import java.util.List;
  25854. +
  25855. /**
  25856. * Performs similarity search on images.
  25857. *
  25858. @@ -66,330 +69,292 @@ import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
  25859. * the single file format (index file packed in the model) is supported.
  25860. */
  25861. public final class ImageSearcher extends BaseVisionTaskApi {
  25862. + private static final String IMAGE_SEARCHER_NATIVE_LIB = "task_vision_jni";
  25863. + private static final int OPTIONAL_FD_LENGTH = -1;
  25864. + private static final int OPTIONAL_FD_OFFSET = -1;
  25865. - private static final String IMAGE_SEARCHER_NATIVE_LIB = "task_vision_jni";
  25866. - private static final int OPTIONAL_FD_LENGTH = -1;
  25867. - private static final int OPTIONAL_FD_OFFSET = -1;
  25868. -
  25869. - /**
  25870. - * Creates an {@link ImageSearcher} instance from {@link ImageSearcherOptions}.
  25871. - *
  25872. - * @param modelPath path of the search model with metadata in the assets
  25873. - * @throws IOException if an I/O error occurs when loading the tflite model or the index file
  25874. - * @throws IllegalArgumentException if an argument is invalid
  25875. - * @throws IllegalStateException if there is an internal error
  25876. - * @throws RuntimeException if there is an otherwise unspecified error
  25877. - */
  25878. - public static ImageSearcher createFromFileAndOptions(
  25879. - Context context, String modelPath, final ImageSearcherOptions options) throws IOException {
  25880. - try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
  25881. - return createFromModelFdAndOptions(
  25882. - /*modelDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
  25883. - /*modelDescriptorLength=*/ assetFileDescriptor.getLength(),
  25884. - /*modelDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
  25885. - options);
  25886. + /**
  25887. + * Creates an {@link ImageSearcher} instance from {@link ImageSearcherOptions}.
  25888. + *
  25889. + * @param modelPath path of the search model with metadata in the assets
  25890. + * @throws IOException if an I/O error occurs when loading the tflite model or the index file
  25891. + * @throws IllegalArgumentException if an argument is invalid
  25892. + * @throws IllegalStateException if there is an internal error
  25893. + * @throws RuntimeException if there is an otherwise unspecified error
  25894. + */
  25895. + public static ImageSearcher createFromFileAndOptions(Context context, String modelPath,
  25896. + final ImageSearcherOptions options) throws IOException {
  25897. + try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
  25898. + return createFromModelFdAndOptions(
  25899. + /*modelDescriptor=*/assetFileDescriptor.getParcelFileDescriptor().getFd(),
  25900. + /*modelDescriptorLength=*/assetFileDescriptor.getLength(),
  25901. + /*modelDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options);
  25902. + }
  25903. }
  25904. - }
  25905. -
  25906. - /**
  25907. - * Creates an {@link ImageSearcher} instance.
  25908. - *
  25909. - * @param modelFile the search model {@link File} instance
  25910. - * @throws IOException if an I/O error occurs when loading the tflite model or the index file
  25911. - * @throws IllegalArgumentException if an argument is invalid
  25912. - * @throws IllegalStateException if there is an internal error
  25913. - * @throws RuntimeException if there is an otherwise unspecified error
  25914. - */
  25915. - public static ImageSearcher createFromFileAndOptions(
  25916. - File modelFile, final ImageSearcherOptions options) throws IOException {
  25917. - try (ParcelFileDescriptor descriptor =
  25918. - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  25919. - return createFromModelFdAndOptions(
  25920. - /*modelDescriptor=*/ descriptor.getFd(),
  25921. - /*modelDescriptorLength=*/ OPTIONAL_FD_LENGTH,
  25922. - /*modelDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
  25923. - options);
  25924. +
  25925. + /**
  25926. + * Creates an {@link ImageSearcher} instance.
  25927. + *
  25928. + * @param modelFile the search model {@link File} instance
  25929. + * @throws IOException if an I/O error occurs when loading the tflite model or the index file
  25930. + * @throws IllegalArgumentException if an argument is invalid
  25931. + * @throws IllegalStateException if there is an internal error
  25932. + * @throws RuntimeException if there is an otherwise unspecified error
  25933. + */
  25934. + public static ImageSearcher createFromFileAndOptions(
  25935. + File modelFile, final ImageSearcherOptions options) throws IOException {
  25936. + try (ParcelFileDescriptor descriptor =
  25937. + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  25938. + return createFromModelFdAndOptions(
  25939. + /*modelDescriptor=*/descriptor.getFd(),
  25940. + /*modelDescriptorLength=*/OPTIONAL_FD_LENGTH,
  25941. + /*modelDescriptorOffset=*/OPTIONAL_FD_OFFSET, options);
  25942. + }
  25943. }
  25944. - }
  25945. -
  25946. - /**
  25947. - * Creates an {@link ImageSearcher} instance with a model buffer and {@link ImageSearcherOptions}.
  25948. - *
  25949. - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the search
  25950. - * model
  25951. - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  25952. - * {@link MappedByteBuffer}
  25953. - * @throws IOException if an I/O error occurs when loading the index file
  25954. - * @throws IllegalStateException if there is an internal error
  25955. - * @throws RuntimeException if there is an otherwise unspecified error
  25956. - */
  25957. - public static ImageSearcher createFromBufferAndOptions(
  25958. - final ByteBuffer modelBuffer, final ImageSearcherOptions options) throws IOException {
  25959. - if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
  25960. - throw new IllegalArgumentException(
  25961. - "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
  25962. +
  25963. + /**
  25964. + * Creates an {@link ImageSearcher} instance with a model buffer and {@link
  25965. + * ImageSearcherOptions}.
  25966. + *
  25967. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the search
  25968. + * model
  25969. + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  25970. + * {@link MappedByteBuffer}
  25971. + * @throws IOException if an I/O error occurs when loading the index file
  25972. + * @throws IllegalStateException if there is an internal error
  25973. + * @throws RuntimeException if there is an otherwise unspecified error
  25974. + */
  25975. + public static ImageSearcher createFromBufferAndOptions(
  25976. + final ByteBuffer modelBuffer, final ImageSearcherOptions options) throws IOException {
  25977. + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
  25978. + throw new IllegalArgumentException(
  25979. + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
  25980. + }
  25981. + if (options.getSearcherOptions().getIndexFile() != null) {
  25982. + try (ParcelFileDescriptor indexDescriptor =
  25983. + ParcelFileDescriptor.open(options.getSearcherOptions().getIndexFile(),
  25984. + ParcelFileDescriptor.MODE_READ_ONLY)) {
  25985. + return createFromBufferAndOptionsImpl(
  25986. + modelBuffer, options, indexDescriptor.getFd());
  25987. + }
  25988. + } else {
  25989. + return createFromBufferAndOptionsImpl(modelBuffer, options, /*indexFd=*/0);
  25990. + }
  25991. }
  25992. - if (options.getSearcherOptions().getIndexFile() != null) {
  25993. - try (ParcelFileDescriptor indexDescriptor =
  25994. - ParcelFileDescriptor.open(
  25995. - options.getSearcherOptions().getIndexFile(), ParcelFileDescriptor.MODE_READ_ONLY)) {
  25996. - return createFromBufferAndOptionsImpl(modelBuffer, options, indexDescriptor.getFd());
  25997. - }
  25998. - } else {
  25999. - return createFromBufferAndOptionsImpl(modelBuffer, options, /*indexFd=*/ 0);
  26000. +
  26001. + public static ImageSearcher createFromBufferAndOptionsImpl(
  26002. + final ByteBuffer modelBuffer, final ImageSearcherOptions options, final int indexFd) {
  26003. + return new ImageSearcher(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
  26004. + @Override
  26005. + public long createHandle() {
  26006. + return initJniWithByteBuffer(modelBuffer,
  26007. + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()),
  26008. + options.getSearcherOptions().getL2Normalize(),
  26009. + options.getSearcherOptions().getQuantize(), indexFd,
  26010. + options.getSearcherOptions().getMaxResults());
  26011. + }
  26012. + }, IMAGE_SEARCHER_NATIVE_LIB));
  26013. }
  26014. - }
  26015. -
  26016. - public static ImageSearcher createFromBufferAndOptionsImpl(
  26017. - final ByteBuffer modelBuffer, final ImageSearcherOptions options, final int indexFd) {
  26018. - return new ImageSearcher(
  26019. - TaskJniUtils.createHandleFromLibrary(
  26020. - new EmptyHandleProvider() {
  26021. - @Override
  26022. - public long createHandle() {
  26023. - return initJniWithByteBuffer(
  26024. - modelBuffer,
  26025. - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()),
  26026. - options.getSearcherOptions().getL2Normalize(),
  26027. - options.getSearcherOptions().getQuantize(),
  26028. - indexFd,
  26029. - options.getSearcherOptions().getMaxResults());
  26030. - }
  26031. - },
  26032. - IMAGE_SEARCHER_NATIVE_LIB));
  26033. - }
  26034. -
  26035. - /**
  26036. - * Constructor to initialize the JNI with a pointer from C++.
  26037. - *
  26038. - * @param nativeHandle a pointer referencing memory allocated in C++
  26039. - */
  26040. - ImageSearcher(long nativeHandle) {
  26041. - super(nativeHandle);
  26042. - }
  26043. -
  26044. - /** Options for setting up an ImageSearcher. */
  26045. - @AutoValue
  26046. - public abstract static class ImageSearcherOptions {
  26047. -
  26048. - abstract BaseOptions getBaseOptions();
  26049. -
  26050. - abstract SearcherOptions getSearcherOptions();
  26051. -
  26052. - public static Builder builder() {
  26053. - return new AutoValue_ImageSearcher_ImageSearcherOptions.Builder()
  26054. - .setBaseOptions(BaseOptions.builder().build())
  26055. - .setSearcherOptions(SearcherOptions.builder().build());
  26056. +
  26057. + /**
  26058. + * Constructor to initialize the JNI with a pointer from C++.
  26059. + *
  26060. + * @param nativeHandle a pointer referencing memory allocated in C++
  26061. + */
  26062. + ImageSearcher(long nativeHandle) {
  26063. + super(nativeHandle);
  26064. }
  26065. - /** Builder for {@link ImageSearcherOptions}. */
  26066. - @AutoValue.Builder
  26067. - public abstract static class Builder {
  26068. - /** Sets the general options to configure Task APIs, such as accelerators. */
  26069. - public abstract Builder setBaseOptions(BaseOptions baseOptions);
  26070. + /** Options for setting up an ImageSearcher. */
  26071. + @AutoValue
  26072. + public abstract static class ImageSearcherOptions {
  26073. + abstract BaseOptions getBaseOptions();
  26074. +
  26075. + abstract SearcherOptions getSearcherOptions();
  26076. +
  26077. + public static Builder builder() {
  26078. + return new AutoValue_ImageSearcher_ImageSearcherOptions.Builder()
  26079. + .setBaseOptions(BaseOptions.builder().build())
  26080. + .setSearcherOptions(SearcherOptions.builder().build());
  26081. + }
  26082. +
  26083. + /** Builder for {@link ImageSearcherOptions}. */
  26084. + @AutoValue.Builder
  26085. + public abstract static class Builder {
  26086. + /** Sets the general options to configure Task APIs, such as accelerators. */
  26087. + public abstract Builder setBaseOptions(BaseOptions baseOptions);
  26088. - /** Sets the options to configure Searcher API. */
  26089. - public abstract Builder setSearcherOptions(SearcherOptions searcherOptions);
  26090. + /** Sets the options to configure Searcher API. */
  26091. + public abstract Builder setSearcherOptions(SearcherOptions searcherOptions);
  26092. - public abstract ImageSearcherOptions build();
  26093. + public abstract ImageSearcherOptions build();
  26094. + }
  26095. }
  26096. - }
  26097. -
  26098. - /**
  26099. - * Performs embedding extraction on the provided {@link TensorImage}, followed by nearest-neighbor
  26100. - * search in the index.
  26101. - *
  26102. - * <p>{@link ImageSearcher} supports the following {@link TensorImage} color space types:
  26103. - *
  26104. - * <ul>
  26105. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
  26106. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
  26107. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
  26108. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
  26109. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
  26110. - * </ul>
  26111. - *
  26112. - * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
  26113. - * @throws IllegalArgumentException if the color space type of image is unsupported
  26114. - */
  26115. - public List<NearestNeighbor> search(TensorImage image) {
  26116. - return search(image, ImageProcessingOptions.builder().build());
  26117. - }
  26118. -
  26119. - /**
  26120. - * Performs embedding extraction on the provided {@link TensorImage} with {@link
  26121. - * ImageProcessingOptions}, followed by nearest-neighbor search in the index.
  26122. - *
  26123. - * <p>{@link ImageSearcher} supports the following options:
  26124. - *
  26125. - * <ul>
  26126. - * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
  26127. - * defaults to the entire image.
  26128. - * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
  26129. - * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
  26130. - * </ul>
  26131. - *
  26132. - * <p>{@link ImageSearcher} supports the following {@link TensorImage} color space types:
  26133. - *
  26134. - * <ul>
  26135. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
  26136. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
  26137. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
  26138. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
  26139. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
  26140. - * </ul>
  26141. - *
  26142. - * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
  26143. - * @throws IllegalArgumentException if the color space type of image is unsupported
  26144. - */
  26145. - public List<NearestNeighbor> search(TensorImage image, ImageProcessingOptions options) {
  26146. - return run(
  26147. - new InferenceProvider<List<NearestNeighbor>>() {
  26148. - @Override
  26149. - public List<NearestNeighbor> run(
  26150. - long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
  26151. - return search(frameBufferHandle, width, height, options);
  26152. - }
  26153. - },
  26154. - image,
  26155. - options);
  26156. - }
  26157. -
  26158. - /**
  26159. - * Performs embedding extraction on the provided {@code MlImage}, followed by nearest-neighbor
  26160. - * search in the index.
  26161. - *
  26162. - * @param image an {@code MlImage} object that represents an image
  26163. - * @throws IllegalArgumentException if the storage type or format of the image is unsupported
  26164. - */
  26165. - public List<NearestNeighbor> search(MlImage image) {
  26166. - return search(image, ImageProcessingOptions.builder().build());
  26167. - }
  26168. -
  26169. - /**
  26170. - * Performs embedding extraction on the provided {@code MlImage} with {@link
  26171. - * ImageProcessingOptions}, followed by nearest-neighbor search in the index.
  26172. - *
  26173. - * <p>{@link ImageSearcher} supports the following options:
  26174. - *
  26175. - * <ul>
  26176. - * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
  26177. - * defaults to the entire image.
  26178. - * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
  26179. - * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
  26180. - * MlImage#getRotation()} is not effective.
  26181. - * </ul>
  26182. - *
  26183. - * @param image a {@code MlImage} object that represents an image
  26184. - * @param options configures options including ROI and rotation
  26185. - * @throws IllegalArgumentException if the storage type or format of the image is unsupported
  26186. - */
  26187. - public List<NearestNeighbor> search(MlImage image, ImageProcessingOptions options) {
  26188. - image.getInternal().acquire();
  26189. - TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
  26190. - List<NearestNeighbor> result = search(tensorImage, options);
  26191. - image.close();
  26192. - return result;
  26193. - }
  26194. -
  26195. - private List<NearestNeighbor> search(
  26196. - long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
  26197. - checkNotClosed();
  26198. - Rect roi = options.getRoi().isEmpty() ? new Rect(0, 0, width, height) : options.getRoi();
  26199. - return searchNative(
  26200. - getNativeHandle(),
  26201. - frameBufferHandle,
  26202. - new int[] {roi.left, roi.top, roi.width(), roi.height()});
  26203. - }
  26204. -
  26205. - private static ImageSearcher createFromModelFdAndOptions(
  26206. - final int modelDescriptor,
  26207. - final long modelDescriptorLength,
  26208. - final long modelDescriptorOffset,
  26209. - final ImageSearcherOptions options)
  26210. - throws IOException {
  26211. - if (options.getSearcherOptions().getIndexFile() != null) {
  26212. - // indexDescriptor must be alive before ImageSearcher is initialized completely in the native
  26213. - // layer.
  26214. - try (ParcelFileDescriptor indexDescriptor =
  26215. - ParcelFileDescriptor.open(
  26216. - options.getSearcherOptions().getIndexFile(), ParcelFileDescriptor.MODE_READ_ONLY)) {
  26217. - return createFromModelFdAndOptionsImpl(
  26218. - modelDescriptor,
  26219. - modelDescriptorLength,
  26220. - modelDescriptorOffset,
  26221. - options,
  26222. - indexDescriptor.getFd());
  26223. - }
  26224. - } else {
  26225. - // Index file is not configured. We'll check if the model contains one in the native layer.
  26226. - return createFromModelFdAndOptionsImpl(
  26227. - modelDescriptor, modelDescriptorLength, modelDescriptorOffset, options, /*indexFd=*/ 0);
  26228. +
  26229. + /**
  26230. + * Performs embedding extraction on the provided {@link TensorImage}, followed by
  26231. + * nearest-neighbor search in the index.
  26232. + *
  26233. + * <p>{@link ImageSearcher} supports the following {@link TensorImage} color space types:
  26234. + *
  26235. + * <ul>
  26236. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
  26237. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
  26238. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
  26239. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
  26240. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
  26241. + * </ul>
  26242. + *
  26243. + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
  26244. + * @throws IllegalArgumentException if the color space type of image is unsupported
  26245. + */
  26246. + public List<NearestNeighbor> search(TensorImage image) {
  26247. + return search(image, ImageProcessingOptions.builder().build());
  26248. + }
  26249. +
  26250. + /**
  26251. + * Performs embedding extraction on the provided {@link TensorImage} with {@link
  26252. + * ImageProcessingOptions}, followed by nearest-neighbor search in the index.
  26253. + *
  26254. + * <p>{@link ImageSearcher} supports the following options:
  26255. + *
  26256. + * <ul>
  26257. + * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
  26258. + * defaults to the entire image.
  26259. + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
  26260. + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
  26261. + * </ul>
  26262. + *
  26263. + * <p>{@link ImageSearcher} supports the following {@link TensorImage} color space types:
  26264. + *
  26265. + * <ul>
  26266. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
  26267. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
  26268. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
  26269. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
  26270. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
  26271. + * </ul>
  26272. + *
  26273. + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
  26274. + * @throws IllegalArgumentException if the color space type of image is unsupported
  26275. + */
  26276. + public List<NearestNeighbor> search(TensorImage image, ImageProcessingOptions options) {
  26277. + return run(new InferenceProvider<List<NearestNeighbor>>() {
  26278. + @Override
  26279. + public List<NearestNeighbor> run(
  26280. + long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
  26281. + return search(frameBufferHandle, width, height, options);
  26282. + }
  26283. + }, image, options);
  26284. + }
  26285. +
  26286. + /**
  26287. + * Performs embedding extraction on the provided {@code MlImage}, followed by nearest-neighbor
  26288. + * search in the index.
  26289. + *
  26290. + * @param image an {@code MlImage} object that represents an image
  26291. + * @throws IllegalArgumentException if the storage type or format of the image is unsupported
  26292. + */
  26293. + public List<NearestNeighbor> search(MlImage image) {
  26294. + return search(image, ImageProcessingOptions.builder().build());
  26295. }
  26296. - }
  26297. -
  26298. - private static ImageSearcher createFromModelFdAndOptionsImpl(
  26299. - final int modelDescriptor,
  26300. - final long modelDescriptorLength,
  26301. - final long modelDescriptorOffset,
  26302. - final ImageSearcherOptions options,
  26303. - final int indexFd) {
  26304. - long nativeHandle =
  26305. - TaskJniUtils.createHandleFromLibrary(
  26306. - new EmptyHandleProvider() {
  26307. - @Override
  26308. - public long createHandle() {
  26309. - return initJniWithModelFdAndOptions(
  26310. - modelDescriptor,
  26311. - modelDescriptorLength,
  26312. - modelDescriptorOffset,
  26313. - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()),
  26314. - options.getSearcherOptions().getL2Normalize(),
  26315. - options.getSearcherOptions().getQuantize(),
  26316. - indexFd,
  26317. - options.getSearcherOptions().getMaxResults());
  26318. - }
  26319. - },
  26320. - IMAGE_SEARCHER_NATIVE_LIB);
  26321. - return new ImageSearcher(nativeHandle);
  26322. - }
  26323. -
  26324. - private static native long initJniWithModelFdAndOptions(
  26325. - int modelDescriptor,
  26326. - long modelDescriptorLength,
  26327. - long modelDescriptorOffset,
  26328. - long baseOptionsHandle,
  26329. - boolean l2Normalize,
  26330. - boolean quantize,
  26331. - int indexDescriptor,
  26332. - int maxResults);
  26333. -
  26334. - private static native long initJniWithByteBuffer(
  26335. - ByteBuffer modelBuffer,
  26336. - long baseOptionsHandle,
  26337. - boolean l2Normalize,
  26338. - boolean quantize,
  26339. - int indexFileDescriptor,
  26340. - int maxResults);
  26341. -
  26342. - /**
  26343. - * The native method to search an image based on the ROI specified.
  26344. - *
  26345. - * @param roi the ROI of the input image, an array representing the bounding box as {left, top,
  26346. - * width, height}
  26347. - */
  26348. - private static native List<NearestNeighbor> searchNative(
  26349. - long nativeHandle, long frameBufferHandle, int[] roi);
  26350. -
  26351. - @Override
  26352. - protected void deinit(long nativeHandle) {
  26353. - deinitJni(nativeHandle);
  26354. - }
  26355. -
  26356. - /**
  26357. - * Native implementation to release memory pointed by the pointer.
  26358. - *
  26359. - * @param nativeHandle pointer to memory allocated
  26360. - */
  26361. - private native void deinitJni(long nativeHandle);
  26362. +
  26363. + /**
  26364. + * Performs embedding extraction on the provided {@code MlImage} with {@link
  26365. + * ImageProcessingOptions}, followed by nearest-neighbor search in the index.
  26366. + *
  26367. + * <p>{@link ImageSearcher} supports the following options:
  26368. + *
  26369. + * <ul>
  26370. + * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
  26371. + * defaults to the entire image.
  26372. + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
  26373. + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
  26374. + * MlImage#getRotation()} is not effective.
  26375. + * </ul>
  26376. + *
  26377. + * @param image a {@code MlImage} object that represents an image
  26378. + * @param options configures options including ROI and rotation
  26379. + * @throws IllegalArgumentException if the storage type or format of the image is unsupported
  26380. + */
  26381. + public List<NearestNeighbor> search(MlImage image, ImageProcessingOptions options) {
  26382. + image.getInternal().acquire();
  26383. + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
  26384. + List<NearestNeighbor> result = search(tensorImage, options);
  26385. + image.close();
  26386. + return result;
  26387. + }
  26388. +
  26389. + private List<NearestNeighbor> search(
  26390. + long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
  26391. + checkNotClosed();
  26392. + Rect roi = options.getRoi().isEmpty() ? new Rect(0, 0, width, height) : options.getRoi();
  26393. + return searchNative(getNativeHandle(), frameBufferHandle,
  26394. + new int[] {roi.left, roi.top, roi.width(), roi.height()});
  26395. + }
  26396. +
  26397. + private static ImageSearcher createFromModelFdAndOptions(final int modelDescriptor,
  26398. + final long modelDescriptorLength, final long modelDescriptorOffset,
  26399. + final ImageSearcherOptions options) throws IOException {
  26400. + if (options.getSearcherOptions().getIndexFile() != null) {
  26401. + // indexDescriptor must be alive before ImageSearcher is initialized completely in the
  26402. + // native layer.
  26403. + try (ParcelFileDescriptor indexDescriptor =
  26404. + ParcelFileDescriptor.open(options.getSearcherOptions().getIndexFile(),
  26405. + ParcelFileDescriptor.MODE_READ_ONLY)) {
  26406. + return createFromModelFdAndOptionsImpl(modelDescriptor, modelDescriptorLength,
  26407. + modelDescriptorOffset, options, indexDescriptor.getFd());
  26408. + }
  26409. + } else {
  26410. + // Index file is not configured. We'll check if the model contains one in the native
  26411. + // layer.
  26412. + return createFromModelFdAndOptionsImpl(modelDescriptor, modelDescriptorLength,
  26413. + modelDescriptorOffset, options, /*indexFd=*/0);
  26414. + }
  26415. + }
  26416. +
  26417. + private static ImageSearcher createFromModelFdAndOptionsImpl(final int modelDescriptor,
  26418. + final long modelDescriptorLength, final long modelDescriptorOffset,
  26419. + final ImageSearcherOptions options, final int indexFd) {
  26420. + long nativeHandle = TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
  26421. + @Override
  26422. + public long createHandle() {
  26423. + return initJniWithModelFdAndOptions(modelDescriptor, modelDescriptorLength,
  26424. + modelDescriptorOffset,
  26425. + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()),
  26426. + options.getSearcherOptions().getL2Normalize(),
  26427. + options.getSearcherOptions().getQuantize(), indexFd,
  26428. + options.getSearcherOptions().getMaxResults());
  26429. + }
  26430. + }, IMAGE_SEARCHER_NATIVE_LIB);
  26431. + return new ImageSearcher(nativeHandle);
  26432. + }
  26433. +
  26434. + private static native long initJniWithModelFdAndOptions(int modelDescriptor,
  26435. + long modelDescriptorLength, long modelDescriptorOffset, long baseOptionsHandle,
  26436. + boolean l2Normalize, boolean quantize, int indexDescriptor, int maxResults);
  26437. +
  26438. + private static native long initJniWithByteBuffer(ByteBuffer modelBuffer, long baseOptionsHandle,
  26439. + boolean l2Normalize, boolean quantize, int indexFileDescriptor, int maxResults);
  26440. +
  26441. + /**
  26442. + * The native method to search an image based on the ROI specified.
  26443. + *
  26444. + * @param roi the ROI of the input image, an array representing the bounding box as {left, top,
  26445. + * width, height}
  26446. + */
  26447. + private static native List<NearestNeighbor> searchNative(
  26448. + long nativeHandle, long frameBufferHandle, int[] roi);
  26449. +
  26450. + @Override
  26451. + protected void deinit(long nativeHandle) {
  26452. + deinitJni(nativeHandle);
  26453. + }
  26454. +
  26455. + /**
  26456. + * Native implementation to release memory pointed by the pointer.
  26457. + *
  26458. + * @param nativeHandle pointer to memory allocated
  26459. + */
  26460. + private native void deinitJni(long nativeHandle);
  26461. }
  26462. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java
  26463. index a92e70ebc09b4..7a7a5b323f43b 100644
  26464. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java
  26465. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java
  26466. @@ -17,72 +17,74 @@ package org.tensorflow.lite.task.vision.segmenter;
  26467. import android.graphics.Color;
  26468. import android.os.Build;
  26469. +
  26470. import androidx.annotation.RequiresApi;
  26471. +
  26472. import com.google.auto.value.AutoValue;
  26473. +
  26474. import org.tensorflow.lite.task.core.annotations.UsedByReflection;
  26475. /** Represents a label associated with a color for display purposes. */
  26476. @AutoValue
  26477. @UsedByReflection("image_segmentation_jni.cc")
  26478. public abstract class ColoredLabel {
  26479. + /**
  26480. + * Creates a {@link ColoredLabel} object with an ARGB color int.
  26481. + *
  26482. + * @param label the label string, as provided in the label map packed in the TFLite Model
  26483. + * Metadata.
  26484. + * @param displayName the display name of label, as configured through {@link
  26485. + * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale}
  26486. + * @param argb the color components for the label in ARGB. See <a
  26487. + * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android
  26488. + * Color ints.</a> for more details.
  26489. + */
  26490. + @UsedByReflection("image_segmentation_jni.cc")
  26491. + public static ColoredLabel create(String label, String displayName, int argb) {
  26492. + return new AutoValue_ColoredLabel(label, displayName, argb);
  26493. + }
  26494. - /**
  26495. - * Creates a {@link ColoredLabel} object with an ARGB color int.
  26496. - *
  26497. - * @param label the label string, as provided in the label map packed in the TFLite Model
  26498. - * Metadata.
  26499. - * @param displayName the display name of label, as configured through {@link
  26500. - * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale}
  26501. - * @param argb the color components for the label in ARGB. See <a
  26502. - * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android
  26503. - * Color ints.</a> for more details.
  26504. - */
  26505. - @UsedByReflection("image_segmentation_jni.cc")
  26506. - public static ColoredLabel create(String label, String displayName, int argb) {
  26507. - return new AutoValue_ColoredLabel(label, displayName, argb);
  26508. - }
  26509. -
  26510. - /**
  26511. - * Creates a {@link ColoredLabel} object with a {@link android.graphics.Color} instance.
  26512. - *
  26513. - * @param label the label string, as provided in the label map packed in the TFLite Model
  26514. - * Metadata.
  26515. - * @param displayName the display name of label, as configured through {@link
  26516. - * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale}
  26517. - * @param color the color components for the label. The Color instatnce is supported on Android
  26518. - * API level 26 and above. For API level lower than 26, use {@link #create(String, String,
  26519. - * int)}. See <a
  26520. - * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android
  26521. - * Color instances.</a> for more details.
  26522. - */
  26523. - @RequiresApi(Build.VERSION_CODES.O)
  26524. - public static ColoredLabel create(String label, String displayName, Color color) {
  26525. - return new AutoValue_ColoredLabel(label, displayName, color.toArgb());
  26526. - }
  26527. + /**
  26528. + * Creates a {@link ColoredLabel} object with a {@link android.graphics.Color} instance.
  26529. + *
  26530. + * @param label the label string, as provided in the label map packed in the TFLite Model
  26531. + * Metadata.
  26532. + * @param displayName the display name of label, as configured through {@link
  26533. + * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale}
  26534. + * @param color the color components for the label. The Color instatnce is supported on Android
  26535. + * API level 26 and above. For API level lower than 26, use {@link #create(String, String,
  26536. + * int)}. See <a
  26537. + * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android
  26538. + * Color instances.</a> for more details.
  26539. + */
  26540. + @RequiresApi(Build.VERSION_CODES.O)
  26541. + public static ColoredLabel create(String label, String displayName, Color color) {
  26542. + return new AutoValue_ColoredLabel(label, displayName, color.toArgb());
  26543. + }
  26544. - public abstract String getlabel();
  26545. + public abstract String getlabel();
  26546. - public abstract String getDisplayName();
  26547. + public abstract String getDisplayName();
  26548. - /**
  26549. - * Gets the ARGB int that represents the color.
  26550. - *
  26551. - * <p>See <a
  26552. - * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android Color
  26553. - * ints.</a> for more details.
  26554. - */
  26555. - public abstract int getArgb();
  26556. + /**
  26557. + * Gets the ARGB int that represents the color.
  26558. + *
  26559. + * <p>See <a
  26560. + * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android
  26561. + * Color ints.</a> for more details.
  26562. + */
  26563. + public abstract int getArgb();
  26564. - /**
  26565. - * Gets the {@link android.graphics.Color} instance of the underlying color.
  26566. - *
  26567. - * <p>The Color instatnce is supported on Android API level 26 and above. For API level lower than
  26568. - * 26, use {@link #getArgb()}. See <a
  26569. - * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android
  26570. - * Color instances.</a> for more details.
  26571. - */
  26572. - @RequiresApi(Build.VERSION_CODES.O)
  26573. - public Color getColor() {
  26574. - return Color.valueOf(getArgb());
  26575. - }
  26576. + /**
  26577. + * Gets the {@link android.graphics.Color} instance of the underlying color.
  26578. + *
  26579. + * <p>The Color instatnce is supported on Android API level 26 and above. For API level lower
  26580. + * than 26, use {@link #getArgb()}. See <a
  26581. + * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android
  26582. + * Color instances.</a> for more details.
  26583. + */
  26584. + @RequiresApi(Build.VERSION_CODES.O)
  26585. + public Color getColor() {
  26586. + return Color.valueOf(getArgb());
  26587. + }
  26588. }
  26589. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java
  26590. index 0caa7a33e1729..4c3b36304a0e3 100644
  26591. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java
  26592. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java
  26593. @@ -18,16 +18,10 @@ package org.tensorflow.lite.task.vision.segmenter;
  26594. import android.content.Context;
  26595. import android.content.res.AssetFileDescriptor;
  26596. import android.os.ParcelFileDescriptor;
  26597. +
  26598. import com.google.android.odml.image.MlImage;
  26599. import com.google.auto.value.AutoValue;
  26600. -import java.io.File;
  26601. -import java.io.IOException;
  26602. -import java.nio.ByteBuffer;
  26603. -import java.nio.ByteOrder;
  26604. -import java.nio.MappedByteBuffer;
  26605. -import java.util.ArrayList;
  26606. -import java.util.Arrays;
  26607. -import java.util.List;
  26608. +
  26609. import org.tensorflow.lite.support.image.MlImageAdapter;
  26610. import org.tensorflow.lite.support.image.TensorImage;
  26611. import org.tensorflow.lite.task.core.BaseOptions;
  26612. @@ -37,6 +31,15 @@ import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
  26613. import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
  26614. import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
  26615. +import java.io.File;
  26616. +import java.io.IOException;
  26617. +import java.nio.ByteBuffer;
  26618. +import java.nio.ByteOrder;
  26619. +import java.nio.MappedByteBuffer;
  26620. +import java.util.ArrayList;
  26621. +import java.util.Arrays;
  26622. +import java.util.List;
  26623. +
  26624. /**
  26625. * Performs segmentation on images.
  26626. *
  26627. @@ -75,394 +78,365 @@ import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
  26628. * href="https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1">TensorFlow Hub.</a>.
  26629. */
  26630. public final class ImageSegmenter extends BaseVisionTaskApi {
  26631. + private static final String IMAGE_SEGMENTER_NATIVE_LIB = "task_vision_jni";
  26632. + private static final int OPTIONAL_FD_LENGTH = -1;
  26633. + private static final int OPTIONAL_FD_OFFSET = -1;
  26634. +
  26635. + private final OutputType outputType;
  26636. +
  26637. + /**
  26638. + * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
  26639. + *
  26640. + * @param modelPath path of the segmentation model with metadata in the assets
  26641. + * @throws IOException if an I/O error occurs when loading the tflite model
  26642. + * @throws IllegalArgumentException if an argument is invalid
  26643. + * @throws IllegalStateException if there is an internal error
  26644. + * @throws RuntimeException if there is an otherwise unspecified error
  26645. + */
  26646. + public static ImageSegmenter createFromFile(Context context, String modelPath)
  26647. + throws IOException {
  26648. + return createFromFileAndOptions(
  26649. + context, modelPath, ImageSegmenterOptions.builder().build());
  26650. + }
  26651. - private static final String IMAGE_SEGMENTER_NATIVE_LIB = "task_vision_jni";
  26652. - private static final int OPTIONAL_FD_LENGTH = -1;
  26653. - private static final int OPTIONAL_FD_OFFSET = -1;
  26654. -
  26655. - private final OutputType outputType;
  26656. -
  26657. - /**
  26658. - * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
  26659. - *
  26660. - * @param modelPath path of the segmentation model with metadata in the assets
  26661. - * @throws IOException if an I/O error occurs when loading the tflite model
  26662. - * @throws IllegalArgumentException if an argument is invalid
  26663. - * @throws IllegalStateException if there is an internal error
  26664. - * @throws RuntimeException if there is an otherwise unspecified error
  26665. - */
  26666. - public static ImageSegmenter createFromFile(Context context, String modelPath)
  26667. - throws IOException {
  26668. - return createFromFileAndOptions(context, modelPath, ImageSegmenterOptions.builder().build());
  26669. - }
  26670. -
  26671. - /**
  26672. - * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
  26673. - *
  26674. - * @param modelFile the segmentation model {@link File} instance
  26675. - * @throws IOException if an I/O error occurs when loading the tflite model
  26676. - * @throws IllegalArgumentException if an argument is invalid
  26677. - * @throws IllegalStateException if there is an internal error
  26678. - * @throws RuntimeException if there is an otherwise unspecified error
  26679. - */
  26680. - public static ImageSegmenter createFromFile(File modelFile) throws IOException {
  26681. - return createFromFileAndOptions(modelFile, ImageSegmenterOptions.builder().build());
  26682. - }
  26683. -
  26684. - /**
  26685. - * Creates an {@link ImageSegmenter} instance with a model buffer and the default {@link
  26686. - * ImageSegmenterOptions}.
  26687. - *
  26688. - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
  26689. - * segmentation model
  26690. - * @throws IllegalStateException if there is an internal error
  26691. - * @throws RuntimeException if there is an otherwise unspecified error
  26692. - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  26693. - * {@link MappedByteBuffer}
  26694. - */
  26695. - public static ImageSegmenter createFromBuffer(final ByteBuffer modelBuffer) {
  26696. - return createFromBufferAndOptions(modelBuffer, ImageSegmenterOptions.builder().build());
  26697. - }
  26698. -
  26699. - /**
  26700. - * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
  26701. - *
  26702. - * @param modelPath path of the segmentation model with metadata in the assets
  26703. - * @throws IOException if an I/O error occurs when loading the tflite model
  26704. - * @throws IllegalArgumentException if an argument is invalid
  26705. - * @throws IllegalStateException if there is an internal error
  26706. - * @throws RuntimeException if there is an otherwise unspecified error
  26707. - */
  26708. - public static ImageSegmenter createFromFileAndOptions(
  26709. - Context context, String modelPath, final ImageSegmenterOptions options) throws IOException {
  26710. - try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
  26711. - return createFromModelFdAndOptions(
  26712. - /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
  26713. - /*fileDescriptorLength=*/ assetFileDescriptor.getLength(),
  26714. - /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
  26715. - options);
  26716. + /**
  26717. + * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
  26718. + *
  26719. + * @param modelFile the segmentation model {@link File} instance
  26720. + * @throws IOException if an I/O error occurs when loading the tflite model
  26721. + * @throws IllegalArgumentException if an argument is invalid
  26722. + * @throws IllegalStateException if there is an internal error
  26723. + * @throws RuntimeException if there is an otherwise unspecified error
  26724. + */
  26725. + public static ImageSegmenter createFromFile(File modelFile) throws IOException {
  26726. + return createFromFileAndOptions(modelFile, ImageSegmenterOptions.builder().build());
  26727. }
  26728. - }
  26729. -
  26730. - /**
  26731. - * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
  26732. - *
  26733. - * @param modelFile the segmentation model {@link File} instance
  26734. - * @throws IOException if an I/O error occurs when loading the tflite model
  26735. - * @throws IllegalArgumentException if an argument is invalid
  26736. - * @throws IllegalStateException if there is an internal error
  26737. - * @throws RuntimeException if there is an otherwise unspecified error
  26738. - */
  26739. - public static ImageSegmenter createFromFileAndOptions(
  26740. - File modelFile, final ImageSegmenterOptions options) throws IOException {
  26741. - try (ParcelFileDescriptor descriptor =
  26742. - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  26743. - return createFromModelFdAndOptions(
  26744. - /*fileDescriptor=*/ descriptor.getFd(),
  26745. - /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
  26746. - /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
  26747. - options);
  26748. +
  26749. + /**
  26750. + * Creates an {@link ImageSegmenter} instance with a model buffer and the default {@link
  26751. + * ImageSegmenterOptions}.
  26752. + *
  26753. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
  26754. + * segmentation model
  26755. + * @throws IllegalStateException if there is an internal error
  26756. + * @throws RuntimeException if there is an otherwise unspecified error
  26757. + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  26758. + * {@link MappedByteBuffer}
  26759. + */
  26760. + public static ImageSegmenter createFromBuffer(final ByteBuffer modelBuffer) {
  26761. + return createFromBufferAndOptions(modelBuffer, ImageSegmenterOptions.builder().build());
  26762. + }
  26763. +
  26764. + /**
  26765. + * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
  26766. + *
  26767. + * @param modelPath path of the segmentation model with metadata in the assets
  26768. + * @throws IOException if an I/O error occurs when loading the tflite model
  26769. + * @throws IllegalArgumentException if an argument is invalid
  26770. + * @throws IllegalStateException if there is an internal error
  26771. + * @throws RuntimeException if there is an otherwise unspecified error
  26772. + */
  26773. + public static ImageSegmenter createFromFileAndOptions(Context context, String modelPath,
  26774. + final ImageSegmenterOptions options) throws IOException {
  26775. + try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
  26776. + return createFromModelFdAndOptions(
  26777. + /*fileDescriptor=*/assetFileDescriptor.getParcelFileDescriptor().getFd(),
  26778. + /*fileDescriptorLength=*/assetFileDescriptor.getLength(),
  26779. + /*fileDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options);
  26780. + }
  26781. + }
  26782. +
  26783. + /**
  26784. + * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
  26785. + *
  26786. + * @param modelFile the segmentation model {@link File} instance
  26787. + * @throws IOException if an I/O error occurs when loading the tflite model
  26788. + * @throws IllegalArgumentException if an argument is invalid
  26789. + * @throws IllegalStateException if there is an internal error
  26790. + * @throws RuntimeException if there is an otherwise unspecified error
  26791. + */
  26792. + public static ImageSegmenter createFromFileAndOptions(
  26793. + File modelFile, final ImageSegmenterOptions options) throws IOException {
  26794. + try (ParcelFileDescriptor descriptor =
  26795. + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
  26796. + return createFromModelFdAndOptions(
  26797. + /*fileDescriptor=*/descriptor.getFd(),
  26798. + /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH,
  26799. + /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options);
  26800. + }
  26801. + }
  26802. +
  26803. + /**
  26804. + * Creates an {@link ImageSegmenter} instance with a model buffer and {@link
  26805. + * ImageSegmenterOptions}.
  26806. + *
  26807. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
  26808. + * segmentation model
  26809. + * @throws IllegalStateException if there is an internal error
  26810. + * @throws RuntimeException if there is an otherwise unspecified error
  26811. + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  26812. + * {@link MappedByteBuffer}
  26813. + */
  26814. + public static ImageSegmenter createFromBufferAndOptions(
  26815. + final ByteBuffer modelBuffer, final ImageSegmenterOptions options) {
  26816. + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
  26817. + throw new IllegalArgumentException(
  26818. + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
  26819. + }
  26820. + return new ImageSegmenter(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
  26821. + @Override
  26822. + public long createHandle() {
  26823. + return initJniWithByteBuffer(modelBuffer, options.getDisplayNamesLocale(),
  26824. + options.getOutputType().getValue(),
  26825. + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
  26826. + options.getBaseOptions(), options.getNumThreads()));
  26827. + }
  26828. + }, IMAGE_SEGMENTER_NATIVE_LIB), options.getOutputType());
  26829. + }
  26830. +
  26831. + /**
  26832. + * Constructor to initialize the JNI with a pointer from C++.
  26833. + *
  26834. + * @param nativeHandle a pointer referencing memory allocated in C++
  26835. + */
  26836. + private ImageSegmenter(long nativeHandle, OutputType outputType) {
  26837. + super(nativeHandle);
  26838. + this.outputType = outputType;
  26839. + }
  26840. +
  26841. + /** Options for setting up an {@link ImageSegmenter}. */
  26842. + @AutoValue
  26843. + public abstract static class ImageSegmenterOptions {
  26844. + private static final String DEFAULT_DISPLAY_NAME_LOCALE = "en";
  26845. + private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.CATEGORY_MASK;
  26846. + private static final int NUM_THREADS = -1;
  26847. +
  26848. + public abstract BaseOptions getBaseOptions();
  26849. +
  26850. + public abstract String getDisplayNamesLocale();
  26851. +
  26852. + public abstract OutputType getOutputType();
  26853. +
  26854. + public abstract int getNumThreads();
  26855. +
  26856. + public static Builder builder() {
  26857. + return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
  26858. + .setDisplayNamesLocale(DEFAULT_DISPLAY_NAME_LOCALE)
  26859. + .setOutputType(DEFAULT_OUTPUT_TYPE)
  26860. + .setNumThreads(NUM_THREADS)
  26861. + .setBaseOptions(BaseOptions.builder().build());
  26862. + }
  26863. +
  26864. + /** Builder for {@link ImageSegmenterOptions}. */
  26865. + @AutoValue.Builder
  26866. + public abstract static class Builder {
  26867. + /** Sets the general options to configure Task APIs, such as accelerators. */
  26868. + public abstract Builder setBaseOptions(BaseOptions baseOptions);
  26869. +
  26870. + /**
  26871. + * Sets the locale to use for display names specified through the TFLite Model Metadata,
  26872. + * if any.
  26873. + *
  26874. + * <p>Defaults to English({@code "en"}). See the <a
  26875. + * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
  26876. + * Metadata schema file.</a> for the accepted pattern of locale.
  26877. + */
  26878. + public abstract Builder setDisplayNamesLocale(String displayNamesLocale);
  26879. +
  26880. + public abstract Builder setOutputType(OutputType outputType);
  26881. +
  26882. + /**
  26883. + * Sets the number of threads to be used for TFLite ops that support multi-threading
  26884. + * when running inference with CPU. Defaults to -1.
  26885. + *
  26886. + * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has
  26887. + * the effect to let TFLite runtime set the value.
  26888. + *
  26889. + * @deprecated use {@link BaseOptions} to configure number of threads instead. This
  26890. + * method
  26891. + * will override the number of threads configured from {@link BaseOptions}.
  26892. + */
  26893. + @Deprecated
  26894. + public abstract Builder setNumThreads(int numThreads);
  26895. +
  26896. + public abstract ImageSegmenterOptions build();
  26897. + }
  26898. + }
  26899. +
  26900. + /**
  26901. + * Performs actual segmentation on the provided image.
  26902. + *
  26903. + * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types:
  26904. + *
  26905. + * <ul>
  26906. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
  26907. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
  26908. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
  26909. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
  26910. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
  26911. + * </ul>
  26912. + *
  26913. + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
  26914. + * @return results of performing image segmentation. Note that at the time, a single {@link
  26915. + * Segmentation} element is expected to be returned. The result is stored in a {@link List}
  26916. + * for later extension to e.g. instance segmentation models, which may return one
  26917. + * segmentation per object.
  26918. + * @throws IllegalStateException if there is an internal error
  26919. + * @throws RuntimeException if there is an otherwise unspecified error
  26920. + * @throws IllegalArgumentException if the color space type of image is unsupported
  26921. + */
  26922. + public List<Segmentation> segment(TensorImage image) {
  26923. + return segment(image, ImageProcessingOptions.builder().build());
  26924. + }
  26925. +
  26926. + /**
  26927. + * Performs actual segmentation on the provided image with {@link ImageProcessingOptions}.
  26928. + *
  26929. + * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types:
  26930. + *
  26931. + * <ul>
  26932. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
  26933. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
  26934. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
  26935. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
  26936. + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
  26937. + * </ul>
  26938. + *
  26939. + * <p>{@link ImageSegmenter} supports the following options:
  26940. + *
  26941. + * <ul>
  26942. + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
  26943. + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}
  26944. + * </ul>
  26945. + *
  26946. + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
  26947. + * @param options the options configure how to preprocess the image
  26948. + * @return results of performing image segmentation. Note that at the time, a single {@link
  26949. + * Segmentation} element is expected to be returned. The result is stored in a {@link List}
  26950. + * for later extension to e.g. instance segmentation models, which may return one
  26951. + * segmentation per object.
  26952. + * @throws IllegalStateException if there is an internal error
  26953. + * @throws RuntimeException if there is an otherwise unspecified error
  26954. + * @throws IllegalArgumentException if the color space type of image is unsupported
  26955. + */
  26956. + public List<Segmentation> segment(TensorImage image, ImageProcessingOptions options) {
  26957. + return run(new InferenceProvider<List<Segmentation>>() {
  26958. + @Override
  26959. + public List<Segmentation> run(
  26960. + long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
  26961. + return segment(frameBufferHandle, options);
  26962. + }
  26963. + }, image, options);
  26964. }
  26965. - }
  26966. -
  26967. - /**
  26968. - * Creates an {@link ImageSegmenter} instance with a model buffer and {@link
  26969. - * ImageSegmenterOptions}.
  26970. - *
  26971. - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
  26972. - * segmentation model
  26973. - * @throws IllegalStateException if there is an internal error
  26974. - * @throws RuntimeException if there is an otherwise unspecified error
  26975. - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
  26976. - * {@link MappedByteBuffer}
  26977. - */
  26978. - public static ImageSegmenter createFromBufferAndOptions(
  26979. - final ByteBuffer modelBuffer, final ImageSegmenterOptions options) {
  26980. - if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
  26981. - throw new IllegalArgumentException(
  26982. - "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
  26983. +
  26984. + /**
  26985. + * Performs actual segmentation on the provided {@code MlImage}.
  26986. + *
  26987. + * @param image an {@code MlImage} to segment.
  26988. + * @return results of performing image segmentation. Note that at the time, a single {@link
  26989. + * Segmentation} element is expected to be returned. The result is stored in a {@link List}
  26990. + * for later extension to e.g. instance segmentation models, which may return one
  26991. + * segmentation per object.
  26992. + * @throws IllegalStateException if there is an internal error
  26993. + * @throws RuntimeException if there is an otherwise unspecified error
  26994. + * @throws IllegalArgumentException if the storage type or format of the image is unsupported
  26995. + */
  26996. + public List<Segmentation> segment(MlImage image) {
  26997. + return segment(image, ImageProcessingOptions.builder().build());
  26998. }
  26999. - return new ImageSegmenter(
  27000. - TaskJniUtils.createHandleFromLibrary(
  27001. - new EmptyHandleProvider() {
  27002. - @Override
  27003. - public long createHandle() {
  27004. - return initJniWithByteBuffer(
  27005. - modelBuffer,
  27006. - options.getDisplayNamesLocale(),
  27007. - options.getOutputType().getValue(),
  27008. - TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
  27009. - options.getBaseOptions(), options.getNumThreads()));
  27010. - }
  27011. - },
  27012. - IMAGE_SEGMENTER_NATIVE_LIB),
  27013. - options.getOutputType());
  27014. - }
  27015. -
  27016. - /**
  27017. - * Constructor to initialize the JNI with a pointer from C++.
  27018. - *
  27019. - * @param nativeHandle a pointer referencing memory allocated in C++
  27020. - */
  27021. - private ImageSegmenter(long nativeHandle, OutputType outputType) {
  27022. - super(nativeHandle);
  27023. - this.outputType = outputType;
  27024. - }
  27025. -
  27026. - /** Options for setting up an {@link ImageSegmenter}. */
  27027. - @AutoValue
  27028. - public abstract static class ImageSegmenterOptions {
  27029. - private static final String DEFAULT_DISPLAY_NAME_LOCALE = "en";
  27030. - private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.CATEGORY_MASK;
  27031. - private static final int NUM_THREADS = -1;
  27032. -
  27033. - public abstract BaseOptions getBaseOptions();
  27034. -
  27035. - public abstract String getDisplayNamesLocale();
  27036. -
  27037. - public abstract OutputType getOutputType();
  27038. -
  27039. - public abstract int getNumThreads();
  27040. -
  27041. - public static Builder builder() {
  27042. - return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
  27043. - .setDisplayNamesLocale(DEFAULT_DISPLAY_NAME_LOCALE)
  27044. - .setOutputType(DEFAULT_OUTPUT_TYPE)
  27045. - .setNumThreads(NUM_THREADS)
  27046. - .setBaseOptions(BaseOptions.builder().build());
  27047. +
  27048. + /**
  27049. + * Performs actual segmentation on the provided {@code MlImage} with {@link
  27050. + * ImageProcessingOptions}.
  27051. + *
  27052. + * <p>{@link ImageSegmenter} supports the following options:
  27053. + *
  27054. + * <ul>
  27055. + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
  27056. + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
  27057. + * MlImage#getRotation()} is not effective.
  27058. + * </ul>
  27059. + *
  27060. + * @param image an {@code MlImage} to segment.
  27061. + * @param options the options configure how to preprocess the image.
  27062. + * @return results of performing image segmentation. Note that at the time, a single {@link
  27063. + * Segmentation} element is expected to be returned. The result is stored in a {@link List}
  27064. + * for later extension to e.g. instance segmentation models, which may return one
  27065. + * segmentation per object.
  27066. + * @throws IllegalStateException if there is an internal error
  27067. + * @throws RuntimeException if there is an otherwise unspecified error
  27068. + * @throws IllegalArgumentException if the color space type of image is unsupported
  27069. + */
  27070. + public List<Segmentation> segment(MlImage image, ImageProcessingOptions options) {
  27071. + image.getInternal().acquire();
  27072. + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
  27073. + List<Segmentation> result = segment(tensorImage, options);
  27074. + image.close();
  27075. + return result;
  27076. }
  27077. - /** Builder for {@link ImageSegmenterOptions}. */
  27078. - @AutoValue.Builder
  27079. - public abstract static class Builder {
  27080. -
  27081. - /** Sets the general options to configure Task APIs, such as accelerators. */
  27082. - public abstract Builder setBaseOptions(BaseOptions baseOptions);
  27083. -
  27084. - /**
  27085. - * Sets the locale to use for display names specified through the TFLite Model Metadata, if
  27086. - * any.
  27087. - *
  27088. - * <p>Defaults to English({@code "en"}). See the <a
  27089. - * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
  27090. - * Metadata schema file.</a> for the accepted pattern of locale.
  27091. - */
  27092. - public abstract Builder setDisplayNamesLocale(String displayNamesLocale);
  27093. -
  27094. - public abstract Builder setOutputType(OutputType outputType);
  27095. -
  27096. - /**
  27097. - * Sets the number of threads to be used for TFLite ops that support multi-threading when
  27098. - * running inference with CPU. Defaults to -1.
  27099. - *
  27100. - * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
  27101. - * effect to let TFLite runtime set the value.
  27102. - *
  27103. - * @deprecated use {@link BaseOptions} to configure number of threads instead. This method
  27104. - * will override the number of threads configured from {@link BaseOptions}.
  27105. - */
  27106. - @Deprecated
  27107. - public abstract Builder setNumThreads(int numThreads);
  27108. -
  27109. - public abstract ImageSegmenterOptions build();
  27110. + public List<Segmentation> segment(long frameBufferHandle, ImageProcessingOptions options) {
  27111. + checkNotClosed();
  27112. +
  27113. + List<byte[]> maskByteArrays = new ArrayList<>();
  27114. + List<ColoredLabel> coloredLabels = new ArrayList<>();
  27115. + int[] maskShape = new int[2];
  27116. + segmentNative(
  27117. + getNativeHandle(), frameBufferHandle, maskByteArrays, maskShape, coloredLabels);
  27118. +
  27119. + List<ByteBuffer> maskByteBuffers = new ArrayList<>();
  27120. + for (byte[] bytes : maskByteArrays) {
  27121. + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
  27122. + // Change the byte order to little_endian, since the buffers were generated in jni.
  27123. + byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
  27124. + maskByteBuffers.add(byteBuffer);
  27125. + }
  27126. +
  27127. + return Arrays.asList(Segmentation.create(outputType,
  27128. + outputType.createMasksFromBuffer(maskByteBuffers, maskShape), coloredLabels));
  27129. }
  27130. - }
  27131. -
  27132. - /**
  27133. - * Performs actual segmentation on the provided image.
  27134. - *
  27135. - * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types:
  27136. - *
  27137. - * <ul>
  27138. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
  27139. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
  27140. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
  27141. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
  27142. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
  27143. - * </ul>
  27144. - *
  27145. - * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
  27146. - * @return results of performing image segmentation. Note that at the time, a single {@link
  27147. - * Segmentation} element is expected to be returned. The result is stored in a {@link List}
  27148. - * for later extension to e.g. instance segmentation models, which may return one segmentation
  27149. - * per object.
  27150. - * @throws IllegalStateException if there is an internal error
  27151. - * @throws RuntimeException if there is an otherwise unspecified error
  27152. - * @throws IllegalArgumentException if the color space type of image is unsupported
  27153. - */
  27154. - public List<Segmentation> segment(TensorImage image) {
  27155. - return segment(image, ImageProcessingOptions.builder().build());
  27156. - }
  27157. -
  27158. - /**
  27159. - * Performs actual segmentation on the provided image with {@link ImageProcessingOptions}.
  27160. - *
  27161. - * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types:
  27162. - *
  27163. - * <ul>
  27164. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
  27165. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
  27166. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
  27167. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
  27168. - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
  27169. - * </ul>
  27170. - *
  27171. - * <p>{@link ImageSegmenter} supports the following options:
  27172. - *
  27173. - * <ul>
  27174. - * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
  27175. - * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}
  27176. - * </ul>
  27177. - *
  27178. - * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
  27179. - * @param options the options configure how to preprocess the image
  27180. - * @return results of performing image segmentation. Note that at the time, a single {@link
  27181. - * Segmentation} element is expected to be returned. The result is stored in a {@link List}
  27182. - * for later extension to e.g. instance segmentation models, which may return one segmentation
  27183. - * per object.
  27184. - * @throws IllegalStateException if there is an internal error
  27185. - * @throws RuntimeException if there is an otherwise unspecified error
  27186. - * @throws IllegalArgumentException if the color space type of image is unsupported
  27187. - */
  27188. - public List<Segmentation> segment(TensorImage image, ImageProcessingOptions options) {
  27189. - return run(
  27190. - new InferenceProvider<List<Segmentation>>() {
  27191. - @Override
  27192. - public List<Segmentation> run(
  27193. - long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
  27194. - return segment(frameBufferHandle, options);
  27195. - }
  27196. - },
  27197. - image,
  27198. - options);
  27199. - }
  27200. -
  27201. - /**
  27202. - * Performs actual segmentation on the provided {@code MlImage}.
  27203. - *
  27204. - * @param image an {@code MlImage} to segment.
  27205. - * @return results of performing image segmentation. Note that at the time, a single {@link
  27206. - * Segmentation} element is expected to be returned. The result is stored in a {@link List}
  27207. - * for later extension to e.g. instance segmentation models, which may return one segmentation
  27208. - * per object.
  27209. - * @throws IllegalStateException if there is an internal error
  27210. - * @throws RuntimeException if there is an otherwise unspecified error
  27211. - * @throws IllegalArgumentException if the storage type or format of the image is unsupported
  27212. - */
  27213. - public List<Segmentation> segment(MlImage image) {
  27214. - return segment(image, ImageProcessingOptions.builder().build());
  27215. - }
  27216. -
  27217. - /**
  27218. - * Performs actual segmentation on the provided {@code MlImage} with {@link
  27219. - * ImageProcessingOptions}.
  27220. - *
  27221. - * <p>{@link ImageSegmenter} supports the following options:
  27222. - *
  27223. - * <ul>
  27224. - * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
  27225. - * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
  27226. - * MlImage#getRotation()} is not effective.
  27227. - * </ul>
  27228. - *
  27229. - * @param image an {@code MlImage} to segment.
  27230. - * @param options the options configure how to preprocess the image.
  27231. - * @return results of performing image segmentation. Note that at the time, a single {@link
  27232. - * Segmentation} element is expected to be returned. The result is stored in a {@link List}
  27233. - * for later extension to e.g. instance segmentation models, which may return one segmentation
  27234. - * per object.
  27235. - * @throws IllegalStateException if there is an internal error
  27236. - * @throws RuntimeException if there is an otherwise unspecified error
  27237. - * @throws IllegalArgumentException if the color space type of image is unsupported
  27238. - */
  27239. - public List<Segmentation> segment(MlImage image, ImageProcessingOptions options) {
  27240. - image.getInternal().acquire();
  27241. - TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
  27242. - List<Segmentation> result = segment(tensorImage, options);
  27243. - image.close();
  27244. - return result;
  27245. - }
  27246. -
  27247. - public List<Segmentation> segment(long frameBufferHandle, ImageProcessingOptions options) {
  27248. - checkNotClosed();
  27249. -
  27250. - List<byte[]> maskByteArrays = new ArrayList<>();
  27251. - List<ColoredLabel> coloredLabels = new ArrayList<>();
  27252. - int[] maskShape = new int[2];
  27253. - segmentNative(getNativeHandle(), frameBufferHandle, maskByteArrays, maskShape, coloredLabels);
  27254. -
  27255. - List<ByteBuffer> maskByteBuffers = new ArrayList<>();
  27256. - for (byte[] bytes : maskByteArrays) {
  27257. - ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
  27258. - // Change the byte order to little_endian, since the buffers were generated in jni.
  27259. - byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
  27260. - maskByteBuffers.add(byteBuffer);
  27261. +
  27262. + private static ImageSegmenter createFromModelFdAndOptions(final int fileDescriptor,
  27263. + final long fileDescriptorLength, final long fileDescriptorOffset,
  27264. + final ImageSegmenterOptions options) {
  27265. + long nativeHandle = TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
  27266. + @Override
  27267. + public long createHandle() {
  27268. + return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength,
  27269. + fileDescriptorOffset, options.getDisplayNamesLocale(),
  27270. + options.getOutputType().getValue(),
  27271. + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
  27272. + options.getBaseOptions(), options.getNumThreads()));
  27273. + }
  27274. + }, IMAGE_SEGMENTER_NATIVE_LIB);
  27275. + return new ImageSegmenter(nativeHandle, options.getOutputType());
  27276. + }
  27277. +
  27278. + private static native long initJniWithModelFdAndOptions(int fileDescriptor,
  27279. + long fileDescriptorLength, long fileDescriptorOffset, String displayNamesLocale,
  27280. + int outputType, long baseOptionsHandle);
  27281. +
  27282. + private static native long initJniWithByteBuffer(ByteBuffer modelBuffer,
  27283. + String displayNamesLocale, int outputType, long baseOptionsHandle);
  27284. +
  27285. + /**
  27286. + * The native method to segment the image.
  27287. + *
  27288. + * <p>{@code maskBuffers}, {@code maskShape}, {@code coloredLabels} will be updated in the
  27289. + * native layer.
  27290. + */
  27291. + private static native void segmentNative(long nativeHandle, long frameBufferHandle,
  27292. + List<byte[]> maskByteArrays, int[] maskShape, List<ColoredLabel> coloredLabels);
  27293. +
  27294. + @Override
  27295. + protected void deinit(long nativeHandle) {
  27296. + deinitJni(nativeHandle);
  27297. }
  27298. - return Arrays.asList(
  27299. - Segmentation.create(
  27300. - outputType,
  27301. - outputType.createMasksFromBuffer(maskByteBuffers, maskShape),
  27302. - coloredLabels));
  27303. - }
  27304. -
  27305. - private static ImageSegmenter createFromModelFdAndOptions(
  27306. - final int fileDescriptor,
  27307. - final long fileDescriptorLength,
  27308. - final long fileDescriptorOffset,
  27309. - final ImageSegmenterOptions options) {
  27310. - long nativeHandle =
  27311. - TaskJniUtils.createHandleFromLibrary(
  27312. - new EmptyHandleProvider() {
  27313. - @Override
  27314. - public long createHandle() {
  27315. - return initJniWithModelFdAndOptions(
  27316. - fileDescriptor,
  27317. - fileDescriptorLength,
  27318. - fileDescriptorOffset,
  27319. - options.getDisplayNamesLocale(),
  27320. - options.getOutputType().getValue(),
  27321. - TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
  27322. - options.getBaseOptions(), options.getNumThreads()));
  27323. - }
  27324. - },
  27325. - IMAGE_SEGMENTER_NATIVE_LIB);
  27326. - return new ImageSegmenter(nativeHandle, options.getOutputType());
  27327. - }
  27328. -
  27329. - private static native long initJniWithModelFdAndOptions(
  27330. - int fileDescriptor,
  27331. - long fileDescriptorLength,
  27332. - long fileDescriptorOffset,
  27333. - String displayNamesLocale,
  27334. - int outputType,
  27335. - long baseOptionsHandle);
  27336. -
  27337. - private static native long initJniWithByteBuffer(
  27338. - ByteBuffer modelBuffer, String displayNamesLocale, int outputType, long baseOptionsHandle);
  27339. -
  27340. - /**
  27341. - * The native method to segment the image.
  27342. - *
  27343. - * <p>{@code maskBuffers}, {@code maskShape}, {@code coloredLabels} will be updated in the native
  27344. - * layer.
  27345. - */
  27346. - private static native void segmentNative(
  27347. - long nativeHandle,
  27348. - long frameBufferHandle,
  27349. - List<byte[]> maskByteArrays,
  27350. - int[] maskShape,
  27351. - List<ColoredLabel> coloredLabels);
  27352. -
  27353. - @Override
  27354. - protected void deinit(long nativeHandle) {
  27355. - deinitJni(nativeHandle);
  27356. - }
  27357. -
  27358. - /**
  27359. - * Native implementation to release memory pointed by the pointer.
  27360. - *
  27361. - * @param nativeHandle pointer to memory allocated
  27362. - */
  27363. - private native void deinitJni(long nativeHandle);
  27364. + /**
  27365. + * Native implementation to release memory pointed by the pointer.
  27366. + *
  27367. + * @param nativeHandle pointer to memory allocated
  27368. + */
  27369. + private native void deinitJni(long nativeHandle);
  27370. }
  27371. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java
  27372. index 26ace1eaa1783..8c69cf5d152a0 100644
  27373. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java
  27374. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java
  27375. @@ -20,126 +20,128 @@ import static org.tensorflow.lite.DataType.UINT8;
  27376. import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument;
  27377. import static org.tensorflow.lite.support.image.ColorSpaceType.GRAYSCALE;
  27378. +import org.tensorflow.lite.support.image.TensorImage;
  27379. +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  27380. +
  27381. import java.nio.ByteBuffer;
  27382. import java.util.ArrayList;
  27383. import java.util.List;
  27384. -import org.tensorflow.lite.support.image.TensorImage;
  27385. -import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  27386. /**
  27387. * Output mask type. This allows specifying the type of post-processing to perform on the raw model
  27388. * results.
  27389. */
  27390. public enum OutputType {
  27391. -
  27392. - /**
  27393. - * Gives a single output mask where each pixel represents the class which the pixel in the
  27394. - * original image was predicted to belong to.
  27395. - */
  27396. - CATEGORY_MASK(0) {
  27397. /**
  27398. - * {@inheritDoc}
  27399. - *
  27400. - * @throws IllegalArgumentException if more than one {@link TensorImage} are provided, or if the
  27401. - * color space of the {@link TensorImage} is not {@link ColorSpaceType#GRAYSCALE}
  27402. + * Gives a single output mask where each pixel represents the class which the pixel in the
  27403. + * original image was predicted to belong to.
  27404. */
  27405. - @Override
  27406. - void assertMasksMatchColoredLabels(List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
  27407. - checkArgument(
  27408. - masks.size() == 1,
  27409. - "CATRGORY_MASK only allows one TensorImage in the list, providing " + masks.size());
  27410. -
  27411. - TensorImage mask = masks.get(0);
  27412. - checkArgument(
  27413. - mask.getColorSpaceType() == GRAYSCALE,
  27414. - "CATRGORY_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing "
  27415. - + mask.getColorSpaceType());
  27416. - }
  27417. + CATEGORY_MASK(0) {
  27418. + /**
  27419. + * {@inheritDoc}
  27420. + *
  27421. + * @throws IllegalArgumentException if more than one {@link TensorImage} are provided, or if
  27422. + * the
  27423. + * color space of the {@link TensorImage} is not {@link ColorSpaceType#GRAYSCALE}
  27424. + */
  27425. + @Override
  27426. + void assertMasksMatchColoredLabels(
  27427. + List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
  27428. + checkArgument(masks.size() == 1,
  27429. + "CATRGORY_MASK only allows one TensorImage in the list, providing "
  27430. + + masks.size());
  27431. +
  27432. + TensorImage mask = masks.get(0);
  27433. + checkArgument(mask.getColorSpaceType() == GRAYSCALE,
  27434. + "CATRGORY_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing "
  27435. + + mask.getColorSpaceType());
  27436. + }
  27437. +
  27438. + /**
  27439. + * {@inheritDoc}
  27440. + *
  27441. + * @throws IllegalArgumentException if more than one {@link ByteBuffer} are provided in the
  27442. + * list
  27443. + */
  27444. + @Override
  27445. + List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
  27446. + checkArgument(buffers.size() == 1,
  27447. + "CATRGORY_MASK only allows one mask in the buffer list, providing "
  27448. + + buffers.size());
  27449. +
  27450. + List<TensorImage> masks = new ArrayList<>();
  27451. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(UINT8);
  27452. + tensorBuffer.loadBuffer(buffers.get(0), maskShape);
  27453. + TensorImage tensorImage = new TensorImage(UINT8);
  27454. + tensorImage.load(tensorBuffer, GRAYSCALE);
  27455. + masks.add(tensorImage);
  27456. +
  27457. + return masks;
  27458. + }
  27459. + },
  27460. /**
  27461. - * {@inheritDoc}
  27462. - *
  27463. - * @throws IllegalArgumentException if more than one {@link ByteBuffer} are provided in the list
  27464. + * Gives a list of output masks where, for each mask, each pixel represents the prediction
  27465. + * confidence, usually in the [0, 1] range.
  27466. */
  27467. - @Override
  27468. - List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
  27469. - checkArgument(
  27470. - buffers.size() == 1,
  27471. - "CATRGORY_MASK only allows one mask in the buffer list, providing " + buffers.size());
  27472. -
  27473. - List<TensorImage> masks = new ArrayList<>();
  27474. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(UINT8);
  27475. - tensorBuffer.loadBuffer(buffers.get(0), maskShape);
  27476. - TensorImage tensorImage = new TensorImage(UINT8);
  27477. - tensorImage.load(tensorBuffer, GRAYSCALE);
  27478. - masks.add(tensorImage);
  27479. -
  27480. - return masks;
  27481. + CONFIDENCE_MASK(1) {
  27482. + /**
  27483. + * {@inheritDoc}
  27484. + *
  27485. + * @throws IllegalArgumentException if more the size of the masks list does not match the
  27486. + * size
  27487. + * of the coloredlabels list, or if the color space type of the any mask is not {@link
  27488. + * ColorSpaceType#GRAYSCALE}
  27489. + */
  27490. + @Override
  27491. + void assertMasksMatchColoredLabels(
  27492. + List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
  27493. + checkArgument(masks.size() == coloredLabels.size(),
  27494. + String.format(
  27495. + "When using CONFIDENCE_MASK, the number of masks (%d) should match the number of"
  27496. + + " coloredLabels (%d).",
  27497. + masks.size(), coloredLabels.size()));
  27498. +
  27499. + for (TensorImage mask : masks) {
  27500. + checkArgument(mask.getColorSpaceType() == GRAYSCALE,
  27501. + "CONFIDENCE_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing "
  27502. + + mask.getColorSpaceType());
  27503. + }
  27504. + }
  27505. +
  27506. + @Override
  27507. + List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
  27508. + List<TensorImage> masks = new ArrayList<>();
  27509. + for (ByteBuffer buffer : buffers) {
  27510. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(FLOAT32);
  27511. + tensorBuffer.loadBuffer(buffer, maskShape);
  27512. + TensorImage tensorImage = new TensorImage(FLOAT32);
  27513. + tensorImage.load(tensorBuffer, GRAYSCALE);
  27514. + masks.add(tensorImage);
  27515. + }
  27516. + return masks;
  27517. + }
  27518. + };
  27519. +
  27520. + public int getValue() {
  27521. + return value;
  27522. }
  27523. - },
  27524. - /**
  27525. - * Gives a list of output masks where, for each mask, each pixel represents the prediction
  27526. - * confidence, usually in the [0, 1] range.
  27527. - */
  27528. - CONFIDENCE_MASK(1) {
  27529. /**
  27530. - * {@inheritDoc}
  27531. + * Verifies that the given list of masks matches the list of colored labels.
  27532. *
  27533. - * @throws IllegalArgumentException if more the size of the masks list does not match the size
  27534. - * of the coloredlabels list, or if the color space type of the any mask is not {@link
  27535. - * ColorSpaceType#GRAYSCALE}
  27536. + * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
  27537. + * output type
  27538. */
  27539. - @Override
  27540. - void assertMasksMatchColoredLabels(List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
  27541. - checkArgument(
  27542. - masks.size() == coloredLabels.size(),
  27543. - String.format(
  27544. - "When using CONFIDENCE_MASK, the number of masks (%d) should match the number of"
  27545. - + " coloredLabels (%d).",
  27546. - masks.size(), coloredLabels.size()));
  27547. -
  27548. - for (TensorImage mask : masks) {
  27549. - checkArgument(
  27550. - mask.getColorSpaceType() == GRAYSCALE,
  27551. - "CONFIDENCE_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing "
  27552. - + mask.getColorSpaceType());
  27553. - }
  27554. - }
  27555. -
  27556. - @Override
  27557. - List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
  27558. - List<TensorImage> masks = new ArrayList<>();
  27559. - for (ByteBuffer buffer : buffers) {
  27560. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(FLOAT32);
  27561. - tensorBuffer.loadBuffer(buffer, maskShape);
  27562. - TensorImage tensorImage = new TensorImage(FLOAT32);
  27563. - tensorImage.load(tensorBuffer, GRAYSCALE);
  27564. - masks.add(tensorImage);
  27565. - }
  27566. - return masks;
  27567. - }
  27568. - };
  27569. + abstract void assertMasksMatchColoredLabels(
  27570. + List<TensorImage> masks, List<ColoredLabel> coloredLabels);
  27571. - public int getValue() {
  27572. - return value;
  27573. - }
  27574. + /** Creates the masks in {@link TensorImage} based on the data in {@link ByteBuffer}. */
  27575. + abstract List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape);
  27576. - /**
  27577. - * Verifies that the given list of masks matches the list of colored labels.
  27578. - *
  27579. - * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
  27580. - * output type
  27581. - */
  27582. - abstract void assertMasksMatchColoredLabels(
  27583. - List<TensorImage> masks, List<ColoredLabel> coloredLabels);
  27584. + private final int value;
  27585. - /** Creates the masks in {@link TensorImage} based on the data in {@link ByteBuffer}. */
  27586. - abstract List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape);
  27587. -
  27588. - private final int value;
  27589. -
  27590. - private OutputType(int value) {
  27591. - this.value = value;
  27592. - }
  27593. + private OutputType(int value) {
  27594. + this.value = value;
  27595. + }
  27596. }
  27597. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java
  27598. index 018482c7e82db..f5062bc8745f0 100644
  27599. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java
  27600. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java
  27601. @@ -16,67 +16,69 @@ limitations under the License.
  27602. package org.tensorflow.lite.task.vision.segmenter;
  27603. import com.google.auto.value.AutoValue;
  27604. +
  27605. +import org.tensorflow.lite.support.image.TensorImage;
  27606. +
  27607. import java.util.ArrayList;
  27608. import java.util.Collections;
  27609. import java.util.List;
  27610. -import org.tensorflow.lite.support.image.TensorImage;
  27611. /** Represents the segmentation result of an {@link ImageSegmenter}. */
  27612. @AutoValue
  27613. public abstract class Segmentation {
  27614. + /**
  27615. + * Creates a {@link Segmentation} object.
  27616. + *
  27617. + * <p>{@link Segmentation} provides two types of outputs as indicated through {@link
  27618. + * OutputType}:
  27619. + *
  27620. + * <p>{@link OutputType#CATEGORY_MASK}: the result contains a single category mask, which is a
  27621. + * grayscale {@link TensorImage} with shape (height, width), in row major order. The value of
  27622. + * each pixel in this mask represents the class to which the pixel in the mask belongs. The
  27623. + * pixel values are in 1:1 corresponding with the colored labels, i.e. a pixel with value {@code
  27624. + * i} is associated with {@code coloredLabels.get(i)}.
  27625. + *
  27626. + * <p>{@link OutputType#CONFIDENCE_MASK}: the result contains a list of confidence masks, which
  27627. + * are in 1:1 correspondance with the colored labels, i.e. {@link masks.get(i)} is associated
  27628. + * with
  27629. + * {@code coloredLabels.get(i)}. Each confidence mask is a grayscale {@link TensorImage} with
  27630. + * shape (height, width), in row major order. The value of each pixel in these masks represents
  27631. + * the confidence score for this particular class.
  27632. + *
  27633. + * <p>IMPORTANT: segmentation masks are not direcly suited for display, in particular:<br>
  27634. + * \* they are relative to the unrotated input frame, i.e. *not* taking into account the {@code
  27635. + * Orientation} flag of the input FrameBuffer, <br>
  27636. + * \* their dimensions are intrinsic to the model, i.e. *not* dependent on the input FrameBuffer
  27637. + * dimensions.
  27638. + *
  27639. + * <p>Example of such post-processing, assuming: <br>
  27640. + * \* an input FrameBuffer with width=640, height=480, orientation=kLeftBottom (i.e. the image
  27641. + * will be rotated 90° clockwise during preprocessing to make it "upright"), <br>
  27642. + * \* a model outputting masks of size 224x224. <br>
  27643. + * In order to be directly displayable on top of the input image assumed to be displayed *with*
  27644. + * the {@code Orientation} flag taken into account (according to the <a
  27645. + * href="http://jpegclub.org/exif_orientation.html">EXIF specification</a>), the masks need to
  27646. + * be: re-scaled to 640 x 480, then rotated 90° clockwise.
  27647. + *
  27648. + * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
  27649. + * {@code outputType}
  27650. + */
  27651. + static Segmentation create(
  27652. + OutputType outputType, List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
  27653. + outputType.assertMasksMatchColoredLabels(masks, coloredLabels);
  27654. - /**
  27655. - * Creates a {@link Segmentation} object.
  27656. - *
  27657. - * <p>{@link Segmentation} provides two types of outputs as indicated through {@link OutputType}:
  27658. - *
  27659. - * <p>{@link OutputType#CATEGORY_MASK}: the result contains a single category mask, which is a
  27660. - * grayscale {@link TensorImage} with shape (height, width), in row major order. The value of each
  27661. - * pixel in this mask represents the class to which the pixel in the mask belongs. The pixel
  27662. - * values are in 1:1 corresponding with the colored labels, i.e. a pixel with value {@code i} is
  27663. - * associated with {@code coloredLabels.get(i)}.
  27664. - *
  27665. - * <p>{@link OutputType#CONFIDENCE_MASK}: the result contains a list of confidence masks, which
  27666. - * are in 1:1 correspondance with the colored labels, i.e. {@link masks.get(i)} is associated with
  27667. - * {@code coloredLabels.get(i)}. Each confidence mask is a grayscale {@link TensorImage} with
  27668. - * shape (height, width), in row major order. The value of each pixel in these masks represents
  27669. - * the confidence score for this particular class.
  27670. - *
  27671. - * <p>IMPORTANT: segmentation masks are not direcly suited for display, in particular:<br>
  27672. - * \* they are relative to the unrotated input frame, i.e. *not* taking into account the {@code
  27673. - * Orientation} flag of the input FrameBuffer, <br>
  27674. - * \* their dimensions are intrinsic to the model, i.e. *not* dependent on the input FrameBuffer
  27675. - * dimensions.
  27676. - *
  27677. - * <p>Example of such post-processing, assuming: <br>
  27678. - * \* an input FrameBuffer with width=640, height=480, orientation=kLeftBottom (i.e. the image
  27679. - * will be rotated 90° clockwise during preprocessing to make it "upright"), <br>
  27680. - * \* a model outputting masks of size 224x224. <br>
  27681. - * In order to be directly displayable on top of the input image assumed to be displayed *with*
  27682. - * the {@code Orientation} flag taken into account (according to the <a
  27683. - * href="http://jpegclub.org/exif_orientation.html">EXIF specification</a>), the masks need to be:
  27684. - * re-scaled to 640 x 480, then rotated 90° clockwise.
  27685. - *
  27686. - * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
  27687. - * {@code outputType}
  27688. - */
  27689. - static Segmentation create(
  27690. - OutputType outputType, List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
  27691. - outputType.assertMasksMatchColoredLabels(masks, coloredLabels);
  27692. -
  27693. - return new AutoValue_Segmentation(
  27694. - outputType,
  27695. - Collections.unmodifiableList(new ArrayList<TensorImage>(masks)),
  27696. - Collections.unmodifiableList(new ArrayList<ColoredLabel>(coloredLabels)));
  27697. - }
  27698. + return new AutoValue_Segmentation(outputType,
  27699. + Collections.unmodifiableList(new ArrayList<TensorImage>(masks)),
  27700. + Collections.unmodifiableList(new ArrayList<ColoredLabel>(coloredLabels)));
  27701. + }
  27702. - public abstract OutputType getOutputType();
  27703. + public abstract OutputType getOutputType();
  27704. - // As an open source project, we've been trying avoiding depending on common java libraries,
  27705. - // such as Guava, because it may introduce conflicts with clients who also happen to use those
  27706. - // libraries. Therefore, instead of using ImmutableList here, we convert the List into
  27707. - // unmodifiableList in create() to make it less vulnerable.
  27708. - public abstract List<TensorImage> getMasks();
  27709. + // As an open source project, we've been trying avoiding depending on common java libraries,
  27710. + // such as Guava, because it may introduce conflicts with clients who also happen to use those
  27711. + // libraries. Therefore, instead of using ImmutableList here, we convert the List into
  27712. + // unmodifiableList in create() to make it less vulnerable.
  27713. + public abstract List<TensorImage> getMasks();
  27714. - public abstract List<ColoredLabel> getColoredLabels();
  27715. + public abstract List<ColoredLabel> getColoredLabels();
  27716. }
  27717. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java
  27718. index f53cfd7a9510a..02aa581c3559c 100644
  27719. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java
  27720. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java
  27721. @@ -16,6 +16,7 @@ limitations under the License.
  27722. package org.tensorflow.lite.support.audio;
  27723. import static com.google.common.truth.Truth.assertThat;
  27724. +
  27725. import static org.junit.Assert.assertThrows;
  27726. import static org.mockito.ArgumentMatchers.any;
  27727. import static org.mockito.ArgumentMatchers.anyInt;
  27728. @@ -25,6 +26,7 @@ import static org.mockito.Mockito.when;
  27729. import android.media.AudioFormat;
  27730. import android.media.AudioRecord;
  27731. +
  27732. import org.junit.Test;
  27733. import org.junit.runner.RunWith;
  27734. import org.junit.runners.Suite;
  27735. @@ -35,259 +37,258 @@ import org.tensorflow.lite.support.audio.TensorAudio.TensorAudioFormat;
  27736. /** Test for {@link TensorAudio}. */
  27737. @RunWith(Suite.class)
  27738. @SuiteClasses({
  27739. - TensorAudioTest.General.class,
  27740. + TensorAudioTest.General.class,
  27741. })
  27742. public class TensorAudioTest {
  27743. -
  27744. - /** General tests of TensorAudio. */
  27745. - @RunWith(RobolectricTestRunner.class)
  27746. - public static final class General extends TensorAudioTest {
  27747. - @Test
  27748. - public void createSucceedsWithTensorAudioFormat() throws Exception {
  27749. - TensorAudio tensor =
  27750. - TensorAudio.create(
  27751. - TensorAudioFormat.builder().setChannels(1).setSampleRate(2).build(), 100);
  27752. - assertThat(tensor.getFormat().getChannels()).isEqualTo(1);
  27753. - assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2);
  27754. - assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(100);
  27755. - }
  27756. -
  27757. - @Test
  27758. - public void createSucceedsWithTensorAudioFormatWithMultipleChannels() throws Exception {
  27759. - TensorAudio tensor =
  27760. - TensorAudio.create(
  27761. - TensorAudioFormat.builder().setChannels(5).setSampleRate(2).build(), 100);
  27762. - assertThat(tensor.getFormat().getChannels()).isEqualTo(5);
  27763. - assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2);
  27764. - assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(500);
  27765. - }
  27766. -
  27767. - @Test
  27768. - public void createSucceededsWithDefaultArguments() throws Exception {
  27769. - TensorAudio tensor =
  27770. - TensorAudio.create(TensorAudioFormat.builder().setSampleRate(20).build(), 1000);
  27771. - // Number of channels defaults to 1.
  27772. - assertThat(tensor.getFormat().getChannels()).isEqualTo(1);
  27773. - assertThat(tensor.getFormat().getSampleRate()).isEqualTo(20);
  27774. - assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(1000);
  27775. - }
  27776. -
  27777. - @Test
  27778. - public void createSucceedsWithAudioFormat() throws Exception {
  27779. - AudioFormat format =
  27780. - new AudioFormat.Builder()
  27781. - .setChannelMask(AudioFormat.CHANNEL_IN_STEREO)
  27782. - .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
  27783. - .setSampleRate(16000)
  27784. - .build();
  27785. - TensorAudio tensor = TensorAudio.create(format, 100);
  27786. - // STEREO has 2 channels
  27787. - assertThat(tensor.getFormat().getChannels()).isEqualTo(2);
  27788. - assertThat(tensor.getFormat().getSampleRate()).isEqualTo(16000);
  27789. - // flatSize = channelCount * sampleCount
  27790. - assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(200);
  27791. - }
  27792. -
  27793. - @Test
  27794. - public void createFailedWithInvalidSampleRate() throws Exception {
  27795. - IllegalArgumentException exception =
  27796. - assertThrows(
  27797. - IllegalArgumentException.class,
  27798. - () -> TensorAudio.create(TensorAudioFormat.builder().setSampleRate(0).build(), 100));
  27799. - // Sample rate 0 is not allowed
  27800. - assertThat(exception).hasMessageThat().ignoringCase().contains("sample rate");
  27801. - }
  27802. -
  27803. - @Test
  27804. - public void createFailedWithInvalidChannels() throws Exception {
  27805. - IllegalArgumentException exception =
  27806. - assertThrows(
  27807. - IllegalArgumentException.class,
  27808. - () ->
  27809. - TensorAudio.create(
  27810. - TensorAudioFormat.builder().setSampleRate(1).setChannels(-1).build(), 100));
  27811. - // Negative channels is not allowed
  27812. - assertThat(exception).hasMessageThat().ignoringCase().contains("channels");
  27813. - }
  27814. -
  27815. - @Test
  27816. - public void loadSucceedsFromArray() throws Exception {
  27817. - TensorAudioFormat format =
  27818. - TensorAudioFormat.builder().setChannels(2).setSampleRate(2).build();
  27819. - TensorAudio tensor = TensorAudio.create(format, 2);
  27820. - assertThat(tensor.getTensorBuffer().getFloatArray()).isEqualTo(new float[4]);
  27821. -
  27822. - tensor.load(new float[] {2.f, 0});
  27823. - assertThat(tensor.getTensorBuffer().getFloatArray())
  27824. - .usingTolerance(0.001f)
  27825. - .containsExactly(new float[] {0, 0, 2.f, 0});
  27826. -
  27827. - tensor.load(new float[] {2.f, 3.f}, 0, 2);
  27828. - assertThat(tensor.getTensorBuffer().getFloatArray())
  27829. - .usingTolerance(0.001f)
  27830. - .containsExactly(new float[] {2.f, 0, 2.f, 3.f});
  27831. -
  27832. - tensor.load(new float[] {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}, 1, 6);
  27833. - // The sequence is longer than the ring buffer size so it's expected to keep only the last 4
  27834. - // numbers (index 3 to 6) of the load target sub-sequence (index 1 to 6).
  27835. - assertThat(tensor.getTensorBuffer().getFloatArray())
  27836. - .usingTolerance(0.001f)
  27837. - .containsExactly(new float[] {5.f, 6.f, 7.f, 8.f});
  27838. -
  27839. - tensor.load(new short[] {Short.MAX_VALUE, Short.MIN_VALUE});
  27840. - assertThat(tensor.getTensorBuffer().getFloatArray())
  27841. - .usingTolerance(0.001f)
  27842. - .containsExactly(new float[] {7.f, 8.f, 1.f, -1.f});
  27843. -
  27844. - tensor.load(new short[] {1000, 2000, 3000, 0, 1000, Short.MIN_VALUE, 4000, 5000, 6000}, 3, 6);
  27845. - // The sequence is longer than the ring buffer size so it's expected to keep only the last 4
  27846. - // numbers.
  27847. - assertThat(tensor.getTensorBuffer().getFloatArray())
  27848. - .usingTolerance(0.001f)
  27849. - .containsExactly(
  27850. - new float[] {
  27851. - -1.f, 4000.f / Short.MAX_VALUE, 5000.f / Short.MAX_VALUE, 6000.f / Short.MAX_VALUE
  27852. - });
  27853. - }
  27854. -
  27855. - @Test
  27856. - public void loadFailsWithIndexOutOfRange() throws Exception {
  27857. - TensorAudioFormat format = TensorAudioFormat.builder().setSampleRate(2).build();
  27858. - TensorAudio tensor = TensorAudio.create(format, 5);
  27859. -
  27860. - assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[100], 99, 2));
  27861. -
  27862. - assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[100], 99, 2));
  27863. - }
  27864. -
  27865. - @Test
  27866. - public void loadFailsWithIncompatibleInputSize() throws Exception {
  27867. - TensorAudioFormat format =
  27868. - TensorAudioFormat.builder().setChannels(3).setSampleRate(2).build();
  27869. - TensorAudio tensor = TensorAudio.create(format, 5);
  27870. -
  27871. - assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[1]));
  27872. -
  27873. - assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[2]));
  27874. -
  27875. - assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[2], 1, 1));
  27876. -
  27877. - assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[5], 2, 4));
  27878. - }
  27879. -
  27880. - @Test
  27881. - public void loadAudioRecordSucceeds() throws Exception {
  27882. - TensorAudio tensor =
  27883. - TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
  27884. - tensor.load(new float[] {1, 2, 3, 4, 5});
  27885. - assertThat(tensor.getTensorBuffer().getFloatArray())
  27886. - .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f});
  27887. -
  27888. - AudioRecord record = mock(AudioRecord.class);
  27889. - when(record.getBufferSizeInFrames()).thenReturn(5);
  27890. - when(record.getChannelCount()).thenReturn(1);
  27891. - when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT);
  27892. - when(record.getFormat())
  27893. - .thenReturn(
  27894. - new AudioFormat.Builder()
  27895. - .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
  27896. - .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
  27897. - .setSampleRate(16000)
  27898. - .build());
  27899. - // Unused
  27900. - when(record.read(any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
  27901. - .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
  27902. - // Used
  27903. - when(record.read(any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
  27904. - .thenReturn(1);
  27905. - assertThat(tensor.load(record)).isEqualTo(1);
  27906. - assertThat(tensor.getTensorBuffer().getFloatArray())
  27907. - .isEqualTo(new float[] {3.f, 4.f, 5.f, 0});
  27908. -
  27909. - record = mock(AudioRecord.class);
  27910. - when(record.getBufferSizeInFrames()).thenReturn(5);
  27911. - when(record.getChannelCount()).thenReturn(1);
  27912. - when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_16BIT);
  27913. - when(record.getFormat())
  27914. - .thenReturn(
  27915. - new AudioFormat.Builder()
  27916. - .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
  27917. - .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
  27918. - .setSampleRate(16000)
  27919. - .build());
  27920. - // Used
  27921. - when(record.read(any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
  27922. - .thenReturn(2);
  27923. - // Unused
  27924. - when(record.read(any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
  27925. - .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
  27926. - assertThat(tensor.load(record)).isEqualTo(2);
  27927. - assertThat(tensor.getTensorBuffer().getFloatArray()).isEqualTo(new float[] {5.f, 0, 0, 0});
  27928. - }
  27929. -
  27930. - @Test
  27931. - public void loadAudioRecordFailsWithErrorState() throws Exception {
  27932. - TensorAudio tensor =
  27933. - TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
  27934. - tensor.load(new float[] {1, 2, 3, 4, 5});
  27935. - assertThat(tensor.getTensorBuffer().getFloatArray())
  27936. - .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f});
  27937. -
  27938. - AudioRecord record = mock(AudioRecord.class);
  27939. - when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT);
  27940. - when(record.getFormat())
  27941. - .thenReturn(
  27942. - new AudioFormat.Builder()
  27943. - .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
  27944. - .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
  27945. - .setSampleRate(16000)
  27946. - .build());
  27947. - // Unused
  27948. - when(record.read(any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
  27949. - .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
  27950. - // Used
  27951. - when(record.read(any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
  27952. - .thenReturn(AudioRecord.ERROR_DEAD_OBJECT);
  27953. - IllegalStateException exception =
  27954. - assertThrows(IllegalStateException.class, () -> tensor.load(record));
  27955. - assertThat(exception).hasMessageThat().contains("ERROR_DEAD_OBJECT");
  27956. - }
  27957. -
  27958. - @Test
  27959. - public void loadAudioRecordFailsWithUnsupportedAudioEncoding() throws Exception {
  27960. - TensorAudio tensor =
  27961. - TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
  27962. - AudioRecord record = mock(AudioRecord.class);
  27963. - when(record.getFormat())
  27964. - .thenReturn(
  27965. - new AudioFormat.Builder()
  27966. - .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
  27967. - .setEncoding(AudioFormat.ENCODING_PCM_8BIT) // Not supported
  27968. - .setSampleRate(16000)
  27969. - .build());
  27970. - when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_8BIT);
  27971. -
  27972. - IllegalArgumentException exception =
  27973. - assertThrows(IllegalArgumentException.class, () -> tensor.load(record));
  27974. - assertThat(exception).hasMessageThat().ignoringCase().contains("unsupported encoding");
  27975. - }
  27976. -
  27977. - @Test
  27978. - public void loadAudioRecordFailsWithIncompatibleAudioFormat() throws Exception {
  27979. - TensorAudio tensor =
  27980. - TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
  27981. - AudioRecord record = mock(AudioRecord.class);
  27982. - when(record.getFormat())
  27983. - .thenReturn(
  27984. - new AudioFormat.Builder()
  27985. - .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
  27986. - .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
  27987. - .setSampleRate(44100) // Mismatch
  27988. - .build());
  27989. -
  27990. - IllegalArgumentException exception =
  27991. - assertThrows(IllegalArgumentException.class, () -> tensor.load(record));
  27992. - assertThat(exception).hasMessageThat().ignoringCase().contains("Incompatible audio format");
  27993. + /** General tests of TensorAudio. */
  27994. + @RunWith(RobolectricTestRunner.class)
  27995. + public static final class General extends TensorAudioTest {
  27996. + @Test
  27997. + public void createSucceedsWithTensorAudioFormat() throws Exception {
  27998. + TensorAudio tensor = TensorAudio.create(
  27999. + TensorAudioFormat.builder().setChannels(1).setSampleRate(2).build(), 100);
  28000. + assertThat(tensor.getFormat().getChannels()).isEqualTo(1);
  28001. + assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2);
  28002. + assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(100);
  28003. + }
  28004. +
  28005. + @Test
  28006. + public void createSucceedsWithTensorAudioFormatWithMultipleChannels() throws Exception {
  28007. + TensorAudio tensor = TensorAudio.create(
  28008. + TensorAudioFormat.builder().setChannels(5).setSampleRate(2).build(), 100);
  28009. + assertThat(tensor.getFormat().getChannels()).isEqualTo(5);
  28010. + assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2);
  28011. + assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(500);
  28012. + }
  28013. +
  28014. + @Test
  28015. + public void createSucceededsWithDefaultArguments() throws Exception {
  28016. + TensorAudio tensor =
  28017. + TensorAudio.create(TensorAudioFormat.builder().setSampleRate(20).build(), 1000);
  28018. + // Number of channels defaults to 1.
  28019. + assertThat(tensor.getFormat().getChannels()).isEqualTo(1);
  28020. + assertThat(tensor.getFormat().getSampleRate()).isEqualTo(20);
  28021. + assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(1000);
  28022. + }
  28023. +
  28024. + @Test
  28025. + public void createSucceedsWithAudioFormat() throws Exception {
  28026. + AudioFormat format = new AudioFormat.Builder()
  28027. + .setChannelMask(AudioFormat.CHANNEL_IN_STEREO)
  28028. + .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
  28029. + .setSampleRate(16000)
  28030. + .build();
  28031. + TensorAudio tensor = TensorAudio.create(format, 100);
  28032. + // STEREO has 2 channels
  28033. + assertThat(tensor.getFormat().getChannels()).isEqualTo(2);
  28034. + assertThat(tensor.getFormat().getSampleRate()).isEqualTo(16000);
  28035. + // flatSize = channelCount * sampleCount
  28036. + assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(200);
  28037. + }
  28038. +
  28039. + @Test
  28040. + public void createFailedWithInvalidSampleRate() throws Exception {
  28041. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  28042. + ()
  28043. + -> TensorAudio.create(
  28044. + TensorAudioFormat.builder().setSampleRate(0).build(), 100));
  28045. + // Sample rate 0 is not allowed
  28046. + assertThat(exception).hasMessageThat().ignoringCase().contains("sample rate");
  28047. + }
  28048. +
  28049. + @Test
  28050. + public void createFailedWithInvalidChannels() throws Exception {
  28051. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  28052. + ()
  28053. + -> TensorAudio.create(TensorAudioFormat.builder()
  28054. + .setSampleRate(1)
  28055. + .setChannels(-1)
  28056. + .build(),
  28057. + 100));
  28058. + // Negative channels is not allowed
  28059. + assertThat(exception).hasMessageThat().ignoringCase().contains("channels");
  28060. + }
  28061. +
  28062. + @Test
  28063. + public void loadSucceedsFromArray() throws Exception {
  28064. + TensorAudioFormat format =
  28065. + TensorAudioFormat.builder().setChannels(2).setSampleRate(2).build();
  28066. + TensorAudio tensor = TensorAudio.create(format, 2);
  28067. + assertThat(tensor.getTensorBuffer().getFloatArray()).isEqualTo(new float[4]);
  28068. +
  28069. + tensor.load(new float[] {2.f, 0});
  28070. + assertThat(tensor.getTensorBuffer().getFloatArray())
  28071. + .usingTolerance(0.001f)
  28072. + .containsExactly(new float[] {0, 0, 2.f, 0});
  28073. +
  28074. + tensor.load(new float[] {2.f, 3.f}, 0, 2);
  28075. + assertThat(tensor.getTensorBuffer().getFloatArray())
  28076. + .usingTolerance(0.001f)
  28077. + .containsExactly(new float[] {2.f, 0, 2.f, 3.f});
  28078. +
  28079. + tensor.load(new float[] {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}, 1, 6);
  28080. + // The sequence is longer than the ring buffer size so it's expected to keep only the
  28081. + // last 4 numbers (index 3 to 6) of the load target sub-sequence (index 1 to 6).
  28082. + assertThat(tensor.getTensorBuffer().getFloatArray())
  28083. + .usingTolerance(0.001f)
  28084. + .containsExactly(new float[] {5.f, 6.f, 7.f, 8.f});
  28085. +
  28086. + tensor.load(new short[] {Short.MAX_VALUE, Short.MIN_VALUE});
  28087. + assertThat(tensor.getTensorBuffer().getFloatArray())
  28088. + .usingTolerance(0.001f)
  28089. + .containsExactly(new float[] {7.f, 8.f, 1.f, -1.f});
  28090. +
  28091. + tensor.load(new short[] {1000, 2000, 3000, 0, 1000, Short.MIN_VALUE, 4000, 5000, 6000},
  28092. + 3, 6);
  28093. + // The sequence is longer than the ring buffer size so it's expected to keep only the
  28094. + // last 4 numbers.
  28095. + assertThat(tensor.getTensorBuffer().getFloatArray())
  28096. + .usingTolerance(0.001f)
  28097. + .containsExactly(new float[] {-1.f, 4000.f / Short.MAX_VALUE,
  28098. + 5000.f / Short.MAX_VALUE, 6000.f / Short.MAX_VALUE});
  28099. + }
  28100. +
  28101. + @Test
  28102. + public void loadFailsWithIndexOutOfRange() throws Exception {
  28103. + TensorAudioFormat format = TensorAudioFormat.builder().setSampleRate(2).build();
  28104. + TensorAudio tensor = TensorAudio.create(format, 5);
  28105. +
  28106. + assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[100], 99, 2));
  28107. +
  28108. + assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[100], 99, 2));
  28109. + }
  28110. +
  28111. + @Test
  28112. + public void loadFailsWithIncompatibleInputSize() throws Exception {
  28113. + TensorAudioFormat format =
  28114. + TensorAudioFormat.builder().setChannels(3).setSampleRate(2).build();
  28115. + TensorAudio tensor = TensorAudio.create(format, 5);
  28116. +
  28117. + assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[1]));
  28118. +
  28119. + assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[2]));
  28120. +
  28121. + assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[2], 1, 1));
  28122. +
  28123. + assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[5], 2, 4));
  28124. + }
  28125. +
  28126. + @Test
  28127. + public void loadAudioRecordSucceeds() throws Exception {
  28128. + TensorAudio tensor =
  28129. + TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
  28130. + tensor.load(new float[] {1, 2, 3, 4, 5});
  28131. + assertThat(tensor.getTensorBuffer().getFloatArray())
  28132. + .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f});
  28133. +
  28134. + AudioRecord record = mock(AudioRecord.class);
  28135. + when(record.getBufferSizeInFrames()).thenReturn(5);
  28136. + when(record.getChannelCount()).thenReturn(1);
  28137. + when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT);
  28138. + when(record.getFormat())
  28139. + .thenReturn(new AudioFormat.Builder()
  28140. + .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
  28141. + .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
  28142. + .setSampleRate(16000)
  28143. + .build());
  28144. + // Unused
  28145. + when(record.read(
  28146. + any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
  28147. + .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
  28148. + // Used
  28149. + when(record.read(
  28150. + any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
  28151. + .thenReturn(1);
  28152. + assertThat(tensor.load(record)).isEqualTo(1);
  28153. + assertThat(tensor.getTensorBuffer().getFloatArray())
  28154. + .isEqualTo(new float[] {3.f, 4.f, 5.f, 0});
  28155. +
  28156. + record = mock(AudioRecord.class);
  28157. + when(record.getBufferSizeInFrames()).thenReturn(5);
  28158. + when(record.getChannelCount()).thenReturn(1);
  28159. + when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_16BIT);
  28160. + when(record.getFormat())
  28161. + .thenReturn(new AudioFormat.Builder()
  28162. + .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
  28163. + .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
  28164. + .setSampleRate(16000)
  28165. + .build());
  28166. + // Used
  28167. + when(record.read(
  28168. + any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
  28169. + .thenReturn(2);
  28170. + // Unused
  28171. + when(record.read(
  28172. + any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
  28173. + .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
  28174. + assertThat(tensor.load(record)).isEqualTo(2);
  28175. + assertThat(tensor.getTensorBuffer().getFloatArray())
  28176. + .isEqualTo(new float[] {5.f, 0, 0, 0});
  28177. + }
  28178. +
  28179. + @Test
  28180. + public void loadAudioRecordFailsWithErrorState() throws Exception {
  28181. + TensorAudio tensor =
  28182. + TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
  28183. + tensor.load(new float[] {1, 2, 3, 4, 5});
  28184. + assertThat(tensor.getTensorBuffer().getFloatArray())
  28185. + .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f});
  28186. +
  28187. + AudioRecord record = mock(AudioRecord.class);
  28188. + when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT);
  28189. + when(record.getFormat())
  28190. + .thenReturn(new AudioFormat.Builder()
  28191. + .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
  28192. + .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
  28193. + .setSampleRate(16000)
  28194. + .build());
  28195. + // Unused
  28196. + when(record.read(
  28197. + any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
  28198. + .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
  28199. + // Used
  28200. + when(record.read(
  28201. + any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
  28202. + .thenReturn(AudioRecord.ERROR_DEAD_OBJECT);
  28203. + IllegalStateException exception =
  28204. + assertThrows(IllegalStateException.class, () -> tensor.load(record));
  28205. + assertThat(exception).hasMessageThat().contains("ERROR_DEAD_OBJECT");
  28206. + }
  28207. +
  28208. + @Test
  28209. + public void loadAudioRecordFailsWithUnsupportedAudioEncoding() throws Exception {
  28210. + TensorAudio tensor =
  28211. + TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
  28212. + AudioRecord record = mock(AudioRecord.class);
  28213. + when(record.getFormat())
  28214. + .thenReturn(new AudioFormat.Builder()
  28215. + .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
  28216. + .setEncoding(AudioFormat.ENCODING_PCM_8BIT) // Not supported
  28217. + .setSampleRate(16000)
  28218. + .build());
  28219. + when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_8BIT);
  28220. +
  28221. + IllegalArgumentException exception =
  28222. + assertThrows(IllegalArgumentException.class, () -> tensor.load(record));
  28223. + assertThat(exception).hasMessageThat().ignoringCase().contains("unsupported encoding");
  28224. + }
  28225. +
  28226. + @Test
  28227. + public void loadAudioRecordFailsWithIncompatibleAudioFormat() throws Exception {
  28228. + TensorAudio tensor =
  28229. + TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
  28230. + AudioRecord record = mock(AudioRecord.class);
  28231. + when(record.getFormat())
  28232. + .thenReturn(new AudioFormat.Builder()
  28233. + .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
  28234. + .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
  28235. + .setSampleRate(44100) // Mismatch
  28236. + .build());
  28237. +
  28238. + IllegalArgumentException exception =
  28239. + assertThrows(IllegalArgumentException.class, () -> tensor.load(record));
  28240. + assertThat(exception).hasMessageThat().ignoringCase().contains(
  28241. + "Incompatible audio format");
  28242. + }
  28243. }
  28244. - }
  28245. }
  28246. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java
  28247. index d97665d1ed771..1d26476733c98 100644
  28248. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java
  28249. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java
  28250. @@ -18,78 +18,81 @@ package org.tensorflow.lite.support.common;
  28251. import static com.google.common.truth.Truth.assertThat;
  28252. import android.content.Context;
  28253. +
  28254. import androidx.test.core.app.ApplicationProvider;
  28255. +
  28256. +import org.junit.Assert;
  28257. +import org.junit.Test;
  28258. +import org.junit.runner.RunWith;
  28259. +import org.robolectric.RobolectricTestRunner;
  28260. +
  28261. import java.io.ByteArrayInputStream;
  28262. import java.io.IOException;
  28263. import java.io.InputStream;
  28264. import java.nio.MappedByteBuffer;
  28265. import java.nio.charset.Charset;
  28266. import java.util.List;
  28267. -import org.junit.Assert;
  28268. -import org.junit.Test;
  28269. -import org.junit.runner.RunWith;
  28270. -import org.robolectric.RobolectricTestRunner;
  28271. /** Tests of {@link org.tensorflow.lite.support.common.FileUtil}. */
  28272. @RunWith(RobolectricTestRunner.class)
  28273. public final class FileUtilTest {
  28274. - private final Context context = ApplicationProvider.getApplicationContext();
  28275. - private static final String LABEL_PATH = "flower_labels.txt";
  28276. -
  28277. - @Test
  28278. - public void testLoadLabels() throws IOException {
  28279. - List<String> labels = FileUtil.loadLabels(context, LABEL_PATH);
  28280. - assertThat(labels)
  28281. - .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips")
  28282. - .inOrder();
  28283. - }
  28284. -
  28285. - @Test
  28286. - public void testLoadLabelsFromInputStream() throws IOException {
  28287. - InputStream inputStream = context.getAssets().open(LABEL_PATH);
  28288. - assertThat(FileUtil.loadLabels(inputStream))
  28289. - .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips")
  28290. - .inOrder();
  28291. - }
  28292. -
  28293. - @Test
  28294. - public void whitespaceLabelsShouldNotCount() throws IOException {
  28295. - String s = "a\nb\n \n\n\nc";
  28296. - InputStream stream = new ByteArrayInputStream(s.getBytes(Charset.defaultCharset()));
  28297. - assertThat(FileUtil.loadLabels(stream)).hasSize(3);
  28298. - }
  28299. -
  28300. - @Test
  28301. - public void testLoadLabelsNullContext() throws IOException {
  28302. - Context nullContext = null;
  28303. - Assert.assertThrows(
  28304. - NullPointerException.class, () -> FileUtil.loadLabels(nullContext, LABEL_PATH));
  28305. - }
  28306. -
  28307. - @Test
  28308. - public void testLoadLabelsNullFilePath() throws IOException {
  28309. - String nullFilePath = null;
  28310. - Assert.assertThrows(
  28311. - NullPointerException.class, () -> FileUtil.loadLabels(context, nullFilePath));
  28312. - }
  28313. -
  28314. - @Test
  28315. - public void testLoadMappedFile() throws IOException {
  28316. - MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, LABEL_PATH);
  28317. - assertThat(byteModel).isNotNull();
  28318. - }
  28319. -
  28320. - @Test
  28321. - public void testLoadMappedFileWithNullContext() throws IOException {
  28322. - Context nullContext = null;
  28323. - Assert.assertThrows(
  28324. - NullPointerException.class, () -> FileUtil.loadMappedFile(nullContext, LABEL_PATH));
  28325. - }
  28326. -
  28327. - @Test
  28328. - public void loadMappedFileWithNullFilePath() throws IOException {
  28329. - String nullFilePath = null;
  28330. - Assert.assertThrows(
  28331. - NullPointerException.class, () -> FileUtil.loadMappedFile(context, nullFilePath));
  28332. - }
  28333. + private final Context context = ApplicationProvider.getApplicationContext();
  28334. + private static final String LABEL_PATH = "flower_labels.txt";
  28335. +
  28336. + @Test
  28337. + public void testLoadLabels() throws IOException {
  28338. + List<String> labels = FileUtil.loadLabels(context, LABEL_PATH);
  28339. + assertThat(labels)
  28340. + .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips")
  28341. + .inOrder();
  28342. + }
  28343. +
  28344. + @Test
  28345. + public void testLoadLabelsFromInputStream() throws IOException {
  28346. + InputStream inputStream = context.getAssets().open(LABEL_PATH);
  28347. + assertThat(FileUtil.loadLabels(inputStream))
  28348. + .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips")
  28349. + .inOrder();
  28350. + }
  28351. +
  28352. + @Test
  28353. + public void whitespaceLabelsShouldNotCount() throws IOException {
  28354. + String s = "a\nb\n \n\n\nc";
  28355. + InputStream stream = new ByteArrayInputStream(s.getBytes(Charset.defaultCharset()));
  28356. + assertThat(FileUtil.loadLabels(stream)).hasSize(3);
  28357. + }
  28358. +
  28359. + @Test
  28360. + public void testLoadLabelsNullContext() throws IOException {
  28361. + Context nullContext = null;
  28362. + Assert.assertThrows(
  28363. + NullPointerException.class, () -> FileUtil.loadLabels(nullContext, LABEL_PATH));
  28364. + }
  28365. +
  28366. + @Test
  28367. + public void testLoadLabelsNullFilePath() throws IOException {
  28368. + String nullFilePath = null;
  28369. + Assert.assertThrows(
  28370. + NullPointerException.class, () -> FileUtil.loadLabels(context, nullFilePath));
  28371. + }
  28372. +
  28373. + @Test
  28374. + public void testLoadMappedFile() throws IOException {
  28375. + MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, LABEL_PATH);
  28376. + assertThat(byteModel).isNotNull();
  28377. + }
  28378. +
  28379. + @Test
  28380. + public void testLoadMappedFileWithNullContext() throws IOException {
  28381. + Context nullContext = null;
  28382. + Assert.assertThrows(
  28383. + NullPointerException.class, () -> FileUtil.loadMappedFile(nullContext, LABEL_PATH));
  28384. + }
  28385. +
  28386. + @Test
  28387. + public void loadMappedFileWithNullFilePath() throws IOException {
  28388. + String nullFilePath = null;
  28389. + Assert.assertThrows(
  28390. + NullPointerException.class, () -> FileUtil.loadMappedFile(context, nullFilePath));
  28391. + }
  28392. }
  28393. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java
  28394. index 43a7f7cd1ce29..82f97f2534cf7 100644
  28395. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java
  28396. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java
  28397. @@ -27,59 +27,58 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  28398. /** Tests for {@link TensorProcessor}. */
  28399. @RunWith(RobolectricTestRunner.class)
  28400. public final class TensorProcessorTest {
  28401. + private static final int EXAMPLE_NUM_FEATURES = 1000;
  28402. + private static final float MEAN = 127.5f;
  28403. + private static final float STDDEV = 127.5f;
  28404. - private static final int EXAMPLE_NUM_FEATURES = 1000;
  28405. - private static final float MEAN = 127.5f;
  28406. - private static final float STDDEV = 127.5f;
  28407. -
  28408. - @Test
  28409. - public void testBuild() {
  28410. - TensorProcessor processor =
  28411. - new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
  28412. - assertThat(processor).isNotNull();
  28413. - }
  28414. + @Test
  28415. + public void testBuild() {
  28416. + TensorProcessor processor =
  28417. + new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
  28418. + assertThat(processor).isNotNull();
  28419. + }
  28420. - @Test
  28421. - public void testNormalize() {
  28422. - TensorBuffer input = createExampleTensorBuffer();
  28423. - TensorProcessor processor =
  28424. - new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
  28425. - TensorBuffer output = processor.process(input);
  28426. + @Test
  28427. + public void testNormalize() {
  28428. + TensorBuffer input = createExampleTensorBuffer();
  28429. + TensorProcessor processor =
  28430. + new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
  28431. + TensorBuffer output = processor.process(input);
  28432. - float[] pixels = output.getFloatArray();
  28433. - assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES);
  28434. - for (float p : pixels) {
  28435. - assertThat(p).isAtLeast(-1);
  28436. - assertThat(p).isAtMost(1);
  28437. + float[] pixels = output.getFloatArray();
  28438. + assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES);
  28439. + for (float p : pixels) {
  28440. + assertThat(p).isAtLeast(-1);
  28441. + assertThat(p).isAtMost(1);
  28442. + }
  28443. }
  28444. - }
  28445. - @Test
  28446. - public void testMultipleNormalize() {
  28447. - TensorBuffer input = createExampleTensorBuffer();
  28448. - TensorProcessor processor =
  28449. - new TensorProcessor.Builder()
  28450. - .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1]
  28451. - .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1]
  28452. - .build();
  28453. - TensorBuffer output = processor.process(input);
  28454. + @Test
  28455. + public void testMultipleNormalize() {
  28456. + TensorBuffer input = createExampleTensorBuffer();
  28457. + TensorProcessor processor =
  28458. + new TensorProcessor.Builder()
  28459. + .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1]
  28460. + .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1]
  28461. + .build();
  28462. + TensorBuffer output = processor.process(input);
  28463. - float[] pixels = output.getFloatArray();
  28464. - assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES);
  28465. - for (float p : pixels) {
  28466. - assertThat(p).isAtLeast(0);
  28467. - assertThat(p).isAtMost(1);
  28468. + float[] pixels = output.getFloatArray();
  28469. + assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES);
  28470. + for (float p : pixels) {
  28471. + assertThat(p).isAtLeast(0);
  28472. + assertThat(p).isAtMost(1);
  28473. + }
  28474. }
  28475. - }
  28476. - // Creates a TensorBuffer of size {1, 1000}, containing values in range [0, 255].
  28477. - private static TensorBuffer createExampleTensorBuffer() {
  28478. - TensorBuffer buffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  28479. - int[] features = new int[EXAMPLE_NUM_FEATURES];
  28480. - for (int i = 0; i < EXAMPLE_NUM_FEATURES; i++) {
  28481. - features[i] = i % 256;
  28482. + // Creates a TensorBuffer of size {1, 1000}, containing values in range [0, 255].
  28483. + private static TensorBuffer createExampleTensorBuffer() {
  28484. + TensorBuffer buffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  28485. + int[] features = new int[EXAMPLE_NUM_FEATURES];
  28486. + for (int i = 0; i < EXAMPLE_NUM_FEATURES; i++) {
  28487. + features[i] = i % 256;
  28488. + }
  28489. + buffer.loadArray(features, new int[] {1, EXAMPLE_NUM_FEATURES});
  28490. + return buffer;
  28491. }
  28492. - buffer.loadArray(features, new int[] {1, EXAMPLE_NUM_FEATURES});
  28493. - return buffer;
  28494. - }
  28495. }
  28496. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java
  28497. index a159c71863322..e8ba24d27550b 100644
  28498. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java
  28499. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java
  28500. @@ -27,56 +27,55 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  28501. /** Tests of {@link CastOp}. */
  28502. @RunWith(RobolectricTestRunner.class)
  28503. public final class CastOpTest {
  28504. + private static final float[] FLOAT_ARRAY = new float[] {1.1f, 3.3f, 5.5f, 7.7f, 9.9f};
  28505. + private static final float[] CASTED_FLOAT_ARRAY = new float[] {1.0f, 3.0f, 5.0f, 7.0f, 9.0f};
  28506. + private static final int[] INT_ARRAY = new int[] {1, 3, 5, 7, 9};
  28507. + private static final int[] SHAPE = new int[] {5};
  28508. - private static final float[] FLOAT_ARRAY = new float[] {1.1f, 3.3f, 5.5f, 7.7f, 9.9f};
  28509. - private static final float[] CASTED_FLOAT_ARRAY = new float[] {1.0f, 3.0f, 5.0f, 7.0f, 9.0f};
  28510. - private static final int[] INT_ARRAY = new int[] {1, 3, 5, 7, 9};
  28511. - private static final int[] SHAPE = new int[] {5};
  28512. -
  28513. - @Test
  28514. - public void castFloat32ToUint8ShouldSuccess() {
  28515. - TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  28516. - floatBuffer.loadArray(FLOAT_ARRAY, SHAPE);
  28517. - CastOp op = new CastOp(DataType.UINT8);
  28518. - TensorBuffer uint8Buffer = op.apply(floatBuffer);
  28519. - assertThat(uint8Buffer.getDataType()).isEqualTo(DataType.UINT8);
  28520. - assertThat(uint8Buffer.getIntArray()).isEqualTo(INT_ARRAY);
  28521. - }
  28522. + @Test
  28523. + public void castFloat32ToUint8ShouldSuccess() {
  28524. + TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  28525. + floatBuffer.loadArray(FLOAT_ARRAY, SHAPE);
  28526. + CastOp op = new CastOp(DataType.UINT8);
  28527. + TensorBuffer uint8Buffer = op.apply(floatBuffer);
  28528. + assertThat(uint8Buffer.getDataType()).isEqualTo(DataType.UINT8);
  28529. + assertThat(uint8Buffer.getIntArray()).isEqualTo(INT_ARRAY);
  28530. + }
  28531. - @Test
  28532. - public void castUint8ToFloat32ShouldSuccess() {
  28533. - TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8);
  28534. - uint8Buffer.loadArray(INT_ARRAY, SHAPE);
  28535. - CastOp op = new CastOp(DataType.FLOAT32);
  28536. - TensorBuffer floatBuffer = op.apply(uint8Buffer);
  28537. - assertThat(floatBuffer.getDataType()).isEqualTo(DataType.FLOAT32);
  28538. - assertThat(floatBuffer.getFloatArray()).isEqualTo(CASTED_FLOAT_ARRAY);
  28539. - }
  28540. + @Test
  28541. + public void castUint8ToFloat32ShouldSuccess() {
  28542. + TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8);
  28543. + uint8Buffer.loadArray(INT_ARRAY, SHAPE);
  28544. + CastOp op = new CastOp(DataType.FLOAT32);
  28545. + TensorBuffer floatBuffer = op.apply(uint8Buffer);
  28546. + assertThat(floatBuffer.getDataType()).isEqualTo(DataType.FLOAT32);
  28547. + assertThat(floatBuffer.getFloatArray()).isEqualTo(CASTED_FLOAT_ARRAY);
  28548. + }
  28549. - @Test
  28550. - public void castFloat32ToFloat32ShouldNotRecreate() {
  28551. - TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  28552. - floatBuffer.loadArray(FLOAT_ARRAY, SHAPE);
  28553. - CastOp op = new CastOp(DataType.FLOAT32);
  28554. - TensorBuffer newBuffer = op.apply(floatBuffer);
  28555. - assertThat(newBuffer.getDataType()).isEqualTo(DataType.FLOAT32);
  28556. - assertThat(newBuffer).isSameInstanceAs(floatBuffer);
  28557. - }
  28558. + @Test
  28559. + public void castFloat32ToFloat32ShouldNotRecreate() {
  28560. + TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  28561. + floatBuffer.loadArray(FLOAT_ARRAY, SHAPE);
  28562. + CastOp op = new CastOp(DataType.FLOAT32);
  28563. + TensorBuffer newBuffer = op.apply(floatBuffer);
  28564. + assertThat(newBuffer.getDataType()).isEqualTo(DataType.FLOAT32);
  28565. + assertThat(newBuffer).isSameInstanceAs(floatBuffer);
  28566. + }
  28567. - @Test
  28568. - public void castUint8ToUint8ShouldNotRecreate() {
  28569. - TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8);
  28570. - uint8Buffer.loadArray(INT_ARRAY, SHAPE);
  28571. - CastOp op = new CastOp(DataType.UINT8);
  28572. - TensorBuffer newBuffer = op.apply(uint8Buffer);
  28573. - assertThat(newBuffer.getDataType()).isEqualTo(DataType.UINT8);
  28574. - assertThat(newBuffer).isSameInstanceAs(uint8Buffer);
  28575. - }
  28576. + @Test
  28577. + public void castUint8ToUint8ShouldNotRecreate() {
  28578. + TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8);
  28579. + uint8Buffer.loadArray(INT_ARRAY, SHAPE);
  28580. + CastOp op = new CastOp(DataType.UINT8);
  28581. + TensorBuffer newBuffer = op.apply(uint8Buffer);
  28582. + assertThat(newBuffer.getDataType()).isEqualTo(DataType.UINT8);
  28583. + assertThat(newBuffer).isSameInstanceAs(uint8Buffer);
  28584. + }
  28585. - @Test
  28586. - public void castToUnsupportedDataTypeShouldThrow() {
  28587. - for (DataType type : new DataType[] {DataType.INT32, DataType.INT64, DataType.STRING}) {
  28588. - Assert.assertThrows(IllegalArgumentException.class, () -> new CastOp(type));
  28589. + @Test
  28590. + public void castToUnsupportedDataTypeShouldThrow() {
  28591. + for (DataType type : new DataType[] {DataType.INT32, DataType.INT64, DataType.STRING}) {
  28592. + Assert.assertThrows(IllegalArgumentException.class, () -> new CastOp(type));
  28593. + }
  28594. }
  28595. - }
  28596. }
  28597. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java
  28598. index 99ded56ce069a..a69bcd7ec0296 100644
  28599. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java
  28600. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java
  28601. @@ -26,16 +26,15 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  28602. /** Tests of {@link DequantizeOp}. */
  28603. @RunWith(RobolectricTestRunner.class)
  28604. public final class DequantizeOpTest {
  28605. -
  28606. - @Test
  28607. - public void dequantizeShouldSucess() {
  28608. - int[] originalData = new int[] {191, 159, 63, 127, 255, 0};
  28609. - DequantizeOp op = new DequantizeOp(127.0f, 1.0f / 128);
  28610. - TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.UINT8);
  28611. - input.loadArray(originalData);
  28612. - TensorBuffer dequantized = op.apply(input);
  28613. - assertThat(dequantized.getDataType()).isEqualTo(DataType.FLOAT32);
  28614. - assertThat(dequantized.getFloatArray())
  28615. - .isEqualTo(new float[] {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f});
  28616. - }
  28617. + @Test
  28618. + public void dequantizeShouldSucess() {
  28619. + int[] originalData = new int[] {191, 159, 63, 127, 255, 0};
  28620. + DequantizeOp op = new DequantizeOp(127.0f, 1.0f / 128);
  28621. + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.UINT8);
  28622. + input.loadArray(originalData);
  28623. + TensorBuffer dequantized = op.apply(input);
  28624. + assertThat(dequantized.getDataType()).isEqualTo(DataType.FLOAT32);
  28625. + assertThat(dequantized.getFloatArray())
  28626. + .isEqualTo(new float[] {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f});
  28627. + }
  28628. }
  28629. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java
  28630. index 09ef275a826bc..aabc6be926106 100644
  28631. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java
  28632. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java
  28633. @@ -16,6 +16,7 @@ limitations under the License.
  28634. package org.tensorflow.lite.support.common.ops;
  28635. import static com.google.common.truth.Truth.assertThat;
  28636. +
  28637. import static org.tensorflow.lite.DataType.FLOAT32;
  28638. import static org.tensorflow.lite.DataType.UINT8;
  28639. @@ -31,122 +32,120 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  28640. */
  28641. @RunWith(RobolectricTestRunner.class)
  28642. public final class NormalizeOpTest {
  28643. + private static final float MEAN = 50;
  28644. + private static final float STDDEV = 50;
  28645. + private static final int NUM_ELEMENTS = 100;
  28646. +
  28647. + @Test
  28648. + public void testNormalizeIntBuffer() {
  28649. + int[] inputArr = new int[NUM_ELEMENTS];
  28650. + for (int i = 0; i < NUM_ELEMENTS; i++) {
  28651. + inputArr[i] = i;
  28652. + }
  28653. + TensorBuffer input = TensorBuffer.createDynamic(DataType.UINT8);
  28654. + input.loadArray(inputArr, new int[] {inputArr.length});
  28655. + NormalizeOp op = new NormalizeOp(MEAN, STDDEV);
  28656. + TensorBuffer output = op.apply(input);
  28657. + assertThat(output.getDataType()).isEqualTo(FLOAT32);
  28658. + float[] outputArr = output.getFloatArray();
  28659. + for (int i = 0; i < NUM_ELEMENTS; i++) {
  28660. + assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV);
  28661. + }
  28662. + }
  28663. - private static final float MEAN = 50;
  28664. - private static final float STDDEV = 50;
  28665. - private static final int NUM_ELEMENTS = 100;
  28666. + @Test
  28667. + public void testNormalizeFloatBuffer() {
  28668. + float[] inputArr = new float[NUM_ELEMENTS];
  28669. + for (int i = 0; i < NUM_ELEMENTS; i++) {
  28670. + inputArr[i] = i;
  28671. + }
  28672. + TensorBuffer input = TensorBuffer.createDynamic(FLOAT32);
  28673. + input.loadArray(inputArr, new int[] {inputArr.length});
  28674. + NormalizeOp op = new NormalizeOp(MEAN, STDDEV);
  28675. + TensorBuffer output = op.apply(input);
  28676. + assertThat(output.getDataType()).isEqualTo(FLOAT32);
  28677. + float[] outputArr = output.getFloatArray();
  28678. + for (int i = 0; i < NUM_ELEMENTS; i++) {
  28679. + assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV);
  28680. + }
  28681. + }
  28682. - @Test
  28683. - public void testNormalizeIntBuffer() {
  28684. - int[] inputArr = new int[NUM_ELEMENTS];
  28685. - for (int i = 0; i < NUM_ELEMENTS; i++) {
  28686. - inputArr[i] = i;
  28687. + @Test
  28688. + public void testZeroStddev() {
  28689. + Assert.assertThrows(IllegalArgumentException.class, () -> new NormalizeOp(1, 0));
  28690. }
  28691. - TensorBuffer input = TensorBuffer.createDynamic(DataType.UINT8);
  28692. - input.loadArray(inputArr, new int[] {inputArr.length});
  28693. - NormalizeOp op = new NormalizeOp(MEAN, STDDEV);
  28694. - TensorBuffer output = op.apply(input);
  28695. - assertThat(output.getDataType()).isEqualTo(FLOAT32);
  28696. - float[] outputArr = output.getFloatArray();
  28697. - for (int i = 0; i < NUM_ELEMENTS; i++) {
  28698. - assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV);
  28699. +
  28700. + @Test
  28701. + public void testIdentityShortcut() {
  28702. + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
  28703. + NormalizeOp op = new NormalizeOp(0, 1);
  28704. + TensorBuffer output = op.apply(input);
  28705. + assertThat(output.getDataType()).isEqualTo(UINT8);
  28706. + assertThat(output).isSameInstanceAs(input);
  28707. }
  28708. - }
  28709. - @Test
  28710. - public void testNormalizeFloatBuffer() {
  28711. - float[] inputArr = new float[NUM_ELEMENTS];
  28712. - for (int i = 0; i < NUM_ELEMENTS; i++) {
  28713. - inputArr[i] = i;
  28714. + @Test
  28715. + public void testNormalizeOp_zeroMeanAndZeroStddev() {
  28716. + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
  28717. + NormalizeOp op = new NormalizeOp(0, 0);
  28718. + TensorBuffer output = op.apply(input);
  28719. + assertThat(output.getDataType()).isEqualTo(UINT8);
  28720. + assertThat(output).isSameInstanceAs(input);
  28721. }
  28722. - TensorBuffer input = TensorBuffer.createDynamic(FLOAT32);
  28723. - input.loadArray(inputArr, new int[] {inputArr.length});
  28724. - NormalizeOp op = new NormalizeOp(MEAN, STDDEV);
  28725. - TensorBuffer output = op.apply(input);
  28726. - assertThat(output.getDataType()).isEqualTo(FLOAT32);
  28727. - float[] outputArr = output.getFloatArray();
  28728. - for (int i = 0; i < NUM_ELEMENTS; i++) {
  28729. - assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV);
  28730. +
  28731. + @Test
  28732. + public void testNormalizeOp_zeroMeanAndInifityStddev() {
  28733. + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
  28734. + NormalizeOp op = new NormalizeOp(0, Float.POSITIVE_INFINITY);
  28735. + TensorBuffer output = op.apply(input);
  28736. + assertThat(output.getDataType()).isEqualTo(UINT8);
  28737. + assertThat(output).isSameInstanceAs(input);
  28738. }
  28739. - }
  28740. -
  28741. - @Test
  28742. - public void testZeroStddev() {
  28743. - Assert.assertThrows(IllegalArgumentException.class, () -> new NormalizeOp(1, 0));
  28744. - }
  28745. -
  28746. - @Test
  28747. - public void testIdentityShortcut() {
  28748. - TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
  28749. - NormalizeOp op = new NormalizeOp(0, 1);
  28750. - TensorBuffer output = op.apply(input);
  28751. - assertThat(output.getDataType()).isEqualTo(UINT8);
  28752. - assertThat(output).isSameInstanceAs(input);
  28753. - }
  28754. -
  28755. - @Test
  28756. - public void testNormalizeOp_zeroMeanAndZeroStddev() {
  28757. - TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
  28758. - NormalizeOp op = new NormalizeOp(0, 0);
  28759. - TensorBuffer output = op.apply(input);
  28760. - assertThat(output.getDataType()).isEqualTo(UINT8);
  28761. - assertThat(output).isSameInstanceAs(input);
  28762. - }
  28763. -
  28764. - @Test
  28765. - public void testNormalizeOp_zeroMeanAndInifityStddev() {
  28766. - TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
  28767. - NormalizeOp op = new NormalizeOp(0, Float.POSITIVE_INFINITY);
  28768. - TensorBuffer output = op.apply(input);
  28769. - assertThat(output.getDataType()).isEqualTo(UINT8);
  28770. - assertThat(output).isSameInstanceAs(input);
  28771. - }
  28772. -
  28773. - @Test
  28774. - public void testMultiChannelNormalize() {
  28775. - float[] inputArr = new float[NUM_ELEMENTS];
  28776. - for (int i = 0; i < NUM_ELEMENTS; i++) {
  28777. - inputArr[i] = i;
  28778. +
  28779. + @Test
  28780. + public void testMultiChannelNormalize() {
  28781. + float[] inputArr = new float[NUM_ELEMENTS];
  28782. + for (int i = 0; i < NUM_ELEMENTS; i++) {
  28783. + inputArr[i] = i;
  28784. + }
  28785. + TensorBuffer input = TensorBuffer.createDynamic(FLOAT32);
  28786. + input.loadArray(inputArr, new int[] {20, 5});
  28787. + float[] means = new float[] {1, 2, 3, 4, 5};
  28788. + float[] stddevs = new float[] {6, 7, 8, 9, 10};
  28789. + NormalizeOp op = new NormalizeOp(means, stddevs);
  28790. + TensorBuffer output = op.apply(input);
  28791. + assertThat(output.getDataType()).isEqualTo(FLOAT32);
  28792. + float[] outputArr = output.getFloatArray();
  28793. + for (int i = 0; i < NUM_ELEMENTS; i++) {
  28794. + assertThat(outputArr[i]).isEqualTo((i - means[i % 5]) / stddevs[i % 5]);
  28795. + }
  28796. }
  28797. - TensorBuffer input = TensorBuffer.createDynamic(FLOAT32);
  28798. - input.loadArray(inputArr, new int[] {20, 5});
  28799. - float[] means = new float[] {1, 2, 3, 4, 5};
  28800. - float[] stddevs = new float[] {6, 7, 8, 9, 10};
  28801. - NormalizeOp op = new NormalizeOp(means, stddevs);
  28802. - TensorBuffer output = op.apply(input);
  28803. - assertThat(output.getDataType()).isEqualTo(FLOAT32);
  28804. - float[] outputArr = output.getFloatArray();
  28805. - for (int i = 0; i < NUM_ELEMENTS; i++) {
  28806. - assertThat(outputArr[i]).isEqualTo((i - means[i % 5]) / stddevs[i % 5]);
  28807. +
  28808. + @Test
  28809. + public void testMultiChannelShortcut() {
  28810. + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
  28811. + NormalizeOp op = new NormalizeOp(new float[] {0, 0, 0}, new float[] {1, 1, 1});
  28812. + TensorBuffer output = op.apply(input);
  28813. + assertThat(output.getDataType()).isEqualTo(UINT8);
  28814. + assertThat(output).isSameInstanceAs(input);
  28815. + }
  28816. +
  28817. + @Test
  28818. + public void testMismatchedNumbersOfMeansAndStddevs() {
  28819. + Assert.assertThrows(IllegalArgumentException.class,
  28820. + () -> new NormalizeOp(new float[] {2, 3}, new float[] {1}));
  28821. + }
  28822. +
  28823. + @Test
  28824. + public void testMismatchedInputTensorChannelNum() {
  28825. + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
  28826. + NormalizeOp op = new NormalizeOp(new float[] {0, 0}, new float[] {1, 2});
  28827. + Assert.assertThrows(IllegalArgumentException.class, () -> op.apply(input));
  28828. + }
  28829. +
  28830. + @Test
  28831. + public void testAnyChannelInvalidStddev() {
  28832. + Assert.assertThrows(IllegalArgumentException.class,
  28833. + () -> new NormalizeOp(new float[] {2, 3}, new float[] {1, 0}));
  28834. }
  28835. - }
  28836. -
  28837. - @Test
  28838. - public void testMultiChannelShortcut() {
  28839. - TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
  28840. - NormalizeOp op = new NormalizeOp(new float[] {0, 0, 0}, new float[] {1, 1, 1});
  28841. - TensorBuffer output = op.apply(input);
  28842. - assertThat(output.getDataType()).isEqualTo(UINT8);
  28843. - assertThat(output).isSameInstanceAs(input);
  28844. - }
  28845. -
  28846. - @Test
  28847. - public void testMismatchedNumbersOfMeansAndStddevs() {
  28848. - Assert.assertThrows(
  28849. - IllegalArgumentException.class, () -> new NormalizeOp(new float[] {2, 3}, new float[] {1}));
  28850. - }
  28851. -
  28852. - @Test
  28853. - public void testMismatchedInputTensorChannelNum() {
  28854. - TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
  28855. - NormalizeOp op = new NormalizeOp(new float[] {0, 0}, new float[] {1, 2});
  28856. - Assert.assertThrows(IllegalArgumentException.class, () -> op.apply(input));
  28857. - }
  28858. -
  28859. - @Test
  28860. - public void testAnyChannelInvalidStddev() {
  28861. - Assert.assertThrows(
  28862. - IllegalArgumentException.class,
  28863. - () -> new NormalizeOp(new float[] {2, 3}, new float[] {1, 0}));
  28864. - }
  28865. }
  28866. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java
  28867. index 8ef72f92e0696..519cd287e1575 100644
  28868. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java
  28869. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java
  28870. @@ -26,15 +26,14 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  28871. /** Tests of {@link QuantizeOp}. */
  28872. @RunWith(RobolectricTestRunner.class)
  28873. public final class QuantizeOpTest {
  28874. -
  28875. - @Test
  28876. - public void quantizeShouldSuccess() {
  28877. - float[] originalData = {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f}; // -0.9921875 == -127 / 128
  28878. - QuantizeOp op = new QuantizeOp(127.0f, 1.0f / 128);
  28879. - TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.FLOAT32);
  28880. - input.loadArray(originalData);
  28881. - TensorBuffer quantized = op.apply(input);
  28882. - assertThat(quantized.getDataType()).isEqualTo(DataType.FLOAT32);
  28883. - assertThat(quantized.getIntArray()).isEqualTo(new int[] {191, 159, 63, 127, 255, 0});
  28884. - }
  28885. + @Test
  28886. + public void quantizeShouldSuccess() {
  28887. + float[] originalData = {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f}; // -0.9921875 == -127 / 128
  28888. + QuantizeOp op = new QuantizeOp(127.0f, 1.0f / 128);
  28889. + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.FLOAT32);
  28890. + input.loadArray(originalData);
  28891. + TensorBuffer quantized = op.apply(input);
  28892. + assertThat(quantized.getDataType()).isEqualTo(DataType.FLOAT32);
  28893. + assertThat(quantized.getIntArray()).isEqualTo(new int[] {191, 159, 63, 127, 255, 0});
  28894. + }
  28895. }
  28896. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java
  28897. index 7f16c8e95628d..e8edb588c61c6 100644
  28898. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java
  28899. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java
  28900. @@ -18,7 +18,7 @@ package org.tensorflow.lite.support.image;
  28901. import static com.google.common.truth.Truth.assertThat;
  28902. import android.graphics.RectF;
  28903. -import java.util.List;
  28904. +
  28905. import org.junit.Assert;
  28906. import org.junit.Before;
  28907. import org.junit.Test;
  28908. @@ -28,213 +28,142 @@ import org.tensorflow.lite.DataType;
  28909. import org.tensorflow.lite.support.image.BoundingBoxUtil.CoordinateType;
  28910. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  28911. +import java.util.List;
  28912. +
  28913. /** Tests of {@link BoundingBoxUtil}. */
  28914. @RunWith(RobolectricTestRunner.class)
  28915. public class BoundingBoxUtilTest {
  28916. -
  28917. - private TensorBuffer tensorBuffer;
  28918. -
  28919. - @Before
  28920. - public void setUp() {
  28921. - // 2 bounding boxes with additional batch dimension.
  28922. - tensorBuffer = TensorBuffer.createFixedSize(new int[] {1, 2, 4}, DataType.FLOAT32);
  28923. - }
  28924. -
  28925. - @Test
  28926. - public void convertDefaultRatioBoundaries() {
  28927. - tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f});
  28928. -
  28929. - List<RectF> boxList =
  28930. - BoundingBoxUtil.convert(
  28931. - tensorBuffer,
  28932. - new int[] {0, 1, 2, 3},
  28933. - -1,
  28934. - BoundingBoxUtil.Type.BOUNDARIES,
  28935. - CoordinateType.RATIO,
  28936. - 500,
  28937. - 400);
  28938. -
  28939. - assertThat(boxList).hasSize(2);
  28940. - assertThat(boxList.get(0)).isEqualTo(new RectF(100, 100, 300, 400));
  28941. - assertThat(boxList.get(1)).isEqualTo(new RectF(200, 0, 400, 500));
  28942. - }
  28943. -
  28944. - @Test
  28945. - public void convertComplexTensor() {
  28946. - tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 4, 2}, DataType.FLOAT32);
  28947. - tensorBuffer.loadArray(
  28948. - new float[] {
  28949. - // sub tensor 0
  28950. - 0, 1, 10, 11, 20, 21, 30, 31,
  28951. - // sub tensor 1
  28952. - 100, 101, 110, 111, 120, 121, 130, 131,
  28953. - // sub tensor 2
  28954. - 200, 201, 210, 211, 220, 221, 230, 231
  28955. - });
  28956. -
  28957. - List<RectF> boxList =
  28958. - BoundingBoxUtil.convert(
  28959. - tensorBuffer,
  28960. - new int[] {0, 1, 2, 3},
  28961. - 1,
  28962. - BoundingBoxUtil.Type.BOUNDARIES,
  28963. - CoordinateType.PIXEL,
  28964. - 0,
  28965. - 0);
  28966. -
  28967. - assertThat(boxList).hasSize(6);
  28968. - assertThat(boxList.get(0)).isEqualTo(new RectF(0, 10, 20, 30));
  28969. - assertThat(boxList.get(1)).isEqualTo(new RectF(1, 11, 21, 31));
  28970. - assertThat(boxList.get(2)).isEqualTo(new RectF(100, 110, 120, 130));
  28971. - assertThat(boxList.get(3)).isEqualTo(new RectF(101, 111, 121, 131));
  28972. - }
  28973. -
  28974. - @Test
  28975. - public void convertIndexedRatioBoundaries() {
  28976. - tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f});
  28977. -
  28978. - List<RectF> boxList =
  28979. - BoundingBoxUtil.convert(
  28980. - tensorBuffer,
  28981. - new int[] {1, 0, 3, 2},
  28982. - -1,
  28983. - BoundingBoxUtil.Type.BOUNDARIES,
  28984. - CoordinateType.RATIO,
  28985. - 500,
  28986. - 400);
  28987. -
  28988. - assertThat(boxList).hasSize(2);
  28989. - assertThat(boxList.get(0)).isEqualTo(new RectF(80, 125, 320, 375));
  28990. - assertThat(boxList.get(1)).isEqualTo(new RectF(0, 250, 400, 500));
  28991. - }
  28992. -
  28993. - @Test
  28994. - public void convertPixelBoundaries() {
  28995. - tensorBuffer.loadArray(new float[] {100, 100, 300, 400, 200, 0, 400, 500});
  28996. -
  28997. - List<RectF> boxList =
  28998. - BoundingBoxUtil.convert(
  28999. - tensorBuffer,
  29000. - new int[] {0, 1, 2, 3},
  29001. - -1,
  29002. - BoundingBoxUtil.Type.BOUNDARIES,
  29003. - CoordinateType.PIXEL,
  29004. - 500,
  29005. - 400);
  29006. -
  29007. - assertThat(boxList)
  29008. - .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
  29009. - .inOrder();
  29010. - }
  29011. -
  29012. - @Test
  29013. - public void convertRatioUpperLeft() {
  29014. - tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.5f, 0.6f, 0.5f, 0.0f, 0.5f, 1.0f});
  29015. -
  29016. - List<RectF> boxList =
  29017. - BoundingBoxUtil.convert(
  29018. - tensorBuffer,
  29019. - new int[] {0, 1, 2, 3},
  29020. - -1,
  29021. - BoundingBoxUtil.Type.UPPER_LEFT,
  29022. - CoordinateType.RATIO,
  29023. - 500,
  29024. - 400);
  29025. -
  29026. - assertThat(boxList).hasSize(2);
  29027. - assertThat(boxList)
  29028. - .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
  29029. - .inOrder();
  29030. - }
  29031. -
  29032. - @Test
  29033. - public void convertPixelUpperLeft() {
  29034. - tensorBuffer.loadArray(new float[] {100, 100, 200, 300, 200, 0, 200, 500});
  29035. -
  29036. - List<RectF> boxList =
  29037. - BoundingBoxUtil.convert(
  29038. - tensorBuffer,
  29039. - new int[] {0, 1, 2, 3},
  29040. - -1,
  29041. - BoundingBoxUtil.Type.UPPER_LEFT,
  29042. - CoordinateType.PIXEL,
  29043. - 500,
  29044. - 400);
  29045. -
  29046. - assertThat(boxList)
  29047. - .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
  29048. - .inOrder();
  29049. - }
  29050. -
  29051. - @Test
  29052. - public void convertRatioCenter() {
  29053. - tensorBuffer.loadArray(new float[] {0.5f, 0.5f, 0.5f, 0.6f, 0.75f, 0.5f, 0.5f, 1.0f});
  29054. -
  29055. - List<RectF> boxList =
  29056. - BoundingBoxUtil.convert(
  29057. - tensorBuffer,
  29058. - new int[] {0, 1, 2, 3},
  29059. - -1,
  29060. - BoundingBoxUtil.Type.CENTER,
  29061. - CoordinateType.RATIO,
  29062. - 500,
  29063. - 400);
  29064. -
  29065. - assertThat(boxList)
  29066. - .containsExactly(new RectF(100, 99.99999f, 300, 400), new RectF(200, 0, 400, 500))
  29067. - .inOrder();
  29068. - }
  29069. -
  29070. - @Test
  29071. - public void convertPixelCenter() {
  29072. - tensorBuffer.loadArray(new float[] {200, 250, 200, 300, 300, 250, 200, 500});
  29073. -
  29074. - List<RectF> boxList =
  29075. - BoundingBoxUtil.convert(
  29076. - tensorBuffer,
  29077. - new int[] {0, 1, 2, 3},
  29078. - -1,
  29079. - BoundingBoxUtil.Type.CENTER,
  29080. - CoordinateType.PIXEL,
  29081. - 500,
  29082. - 400);
  29083. -
  29084. - assertThat(boxList)
  29085. - .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
  29086. - .inOrder();
  29087. - }
  29088. -
  29089. - @Test
  29090. - public void convertTensorWithUnexpectedShapeShouldThrow() {
  29091. - TensorBuffer badShapeTensor = TensorBuffer.createFixedSize(new int[] {1, 5}, DataType.FLOAT32);
  29092. -
  29093. - Assert.assertThrows(
  29094. - IllegalArgumentException.class,
  29095. - () ->
  29096. - BoundingBoxUtil.convert(
  29097. - badShapeTensor,
  29098. - new int[] {0, 1, 2, 3},
  29099. - -1,
  29100. - BoundingBoxUtil.Type.BOUNDARIES,
  29101. - CoordinateType.RATIO,
  29102. - 300,
  29103. - 400));
  29104. - }
  29105. -
  29106. - @Test
  29107. - public void convertIntTensorShouldThrow() {
  29108. - TensorBuffer badTypeTensor = TensorBuffer.createFixedSize(new int[] {1, 4}, DataType.UINT8);
  29109. -
  29110. - Assert.assertThrows(
  29111. - IllegalArgumentException.class,
  29112. - () ->
  29113. - BoundingBoxUtil.convert(
  29114. - badTypeTensor,
  29115. - new int[] {0, 1, 2, 3},
  29116. - -1,
  29117. - BoundingBoxUtil.Type.BOUNDARIES,
  29118. - CoordinateType.RATIO,
  29119. - 300,
  29120. - 400));
  29121. - }
  29122. + private TensorBuffer tensorBuffer;
  29123. +
  29124. + @Before
  29125. + public void setUp() {
  29126. + // 2 bounding boxes with additional batch dimension.
  29127. + tensorBuffer = TensorBuffer.createFixedSize(new int[] {1, 2, 4}, DataType.FLOAT32);
  29128. + }
  29129. +
  29130. + @Test
  29131. + public void convertDefaultRatioBoundaries() {
  29132. + tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f});
  29133. +
  29134. + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
  29135. + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 500, 400);
  29136. +
  29137. + assertThat(boxList).hasSize(2);
  29138. + assertThat(boxList.get(0)).isEqualTo(new RectF(100, 100, 300, 400));
  29139. + assertThat(boxList.get(1)).isEqualTo(new RectF(200, 0, 400, 500));
  29140. + }
  29141. +
  29142. + @Test
  29143. + public void convertComplexTensor() {
  29144. + tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 4, 2}, DataType.FLOAT32);
  29145. + tensorBuffer.loadArray(new float[] {// sub tensor 0
  29146. + 0, 1, 10, 11, 20, 21, 30, 31,
  29147. + // sub tensor 1
  29148. + 100, 101, 110, 111, 120, 121, 130, 131,
  29149. + // sub tensor 2
  29150. + 200, 201, 210, 211, 220, 221, 230, 231});
  29151. +
  29152. + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, 1,
  29153. + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.PIXEL, 0, 0);
  29154. +
  29155. + assertThat(boxList).hasSize(6);
  29156. + assertThat(boxList.get(0)).isEqualTo(new RectF(0, 10, 20, 30));
  29157. + assertThat(boxList.get(1)).isEqualTo(new RectF(1, 11, 21, 31));
  29158. + assertThat(boxList.get(2)).isEqualTo(new RectF(100, 110, 120, 130));
  29159. + assertThat(boxList.get(3)).isEqualTo(new RectF(101, 111, 121, 131));
  29160. + }
  29161. +
  29162. + @Test
  29163. + public void convertIndexedRatioBoundaries() {
  29164. + tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f});
  29165. +
  29166. + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {1, 0, 3, 2}, -1,
  29167. + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 500, 400);
  29168. +
  29169. + assertThat(boxList).hasSize(2);
  29170. + assertThat(boxList.get(0)).isEqualTo(new RectF(80, 125, 320, 375));
  29171. + assertThat(boxList.get(1)).isEqualTo(new RectF(0, 250, 400, 500));
  29172. + }
  29173. +
  29174. + @Test
  29175. + public void convertPixelBoundaries() {
  29176. + tensorBuffer.loadArray(new float[] {100, 100, 300, 400, 200, 0, 400, 500});
  29177. +
  29178. + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
  29179. + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.PIXEL, 500, 400);
  29180. +
  29181. + assertThat(boxList)
  29182. + .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
  29183. + .inOrder();
  29184. + }
  29185. +
  29186. + @Test
  29187. + public void convertRatioUpperLeft() {
  29188. + tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.5f, 0.6f, 0.5f, 0.0f, 0.5f, 1.0f});
  29189. +
  29190. + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
  29191. + BoundingBoxUtil.Type.UPPER_LEFT, CoordinateType.RATIO, 500, 400);
  29192. +
  29193. + assertThat(boxList).hasSize(2);
  29194. + assertThat(boxList)
  29195. + .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
  29196. + .inOrder();
  29197. + }
  29198. +
  29199. + @Test
  29200. + public void convertPixelUpperLeft() {
  29201. + tensorBuffer.loadArray(new float[] {100, 100, 200, 300, 200, 0, 200, 500});
  29202. +
  29203. + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
  29204. + BoundingBoxUtil.Type.UPPER_LEFT, CoordinateType.PIXEL, 500, 400);
  29205. +
  29206. + assertThat(boxList)
  29207. + .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
  29208. + .inOrder();
  29209. + }
  29210. +
  29211. + @Test
  29212. + public void convertRatioCenter() {
  29213. + tensorBuffer.loadArray(new float[] {0.5f, 0.5f, 0.5f, 0.6f, 0.75f, 0.5f, 0.5f, 1.0f});
  29214. +
  29215. + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
  29216. + BoundingBoxUtil.Type.CENTER, CoordinateType.RATIO, 500, 400);
  29217. +
  29218. + assertThat(boxList)
  29219. + .containsExactly(new RectF(100, 99.99999f, 300, 400), new RectF(200, 0, 400, 500))
  29220. + .inOrder();
  29221. + }
  29222. +
  29223. + @Test
  29224. + public void convertPixelCenter() {
  29225. + tensorBuffer.loadArray(new float[] {200, 250, 200, 300, 300, 250, 200, 500});
  29226. +
  29227. + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
  29228. + BoundingBoxUtil.Type.CENTER, CoordinateType.PIXEL, 500, 400);
  29229. +
  29230. + assertThat(boxList)
  29231. + .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
  29232. + .inOrder();
  29233. + }
  29234. +
  29235. + @Test
  29236. + public void convertTensorWithUnexpectedShapeShouldThrow() {
  29237. + TensorBuffer badShapeTensor =
  29238. + TensorBuffer.createFixedSize(new int[] {1, 5}, DataType.FLOAT32);
  29239. +
  29240. + Assert.assertThrows(IllegalArgumentException.class,
  29241. + ()
  29242. + -> BoundingBoxUtil.convert(badShapeTensor, new int[] {0, 1, 2, 3}, -1,
  29243. + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 300, 400));
  29244. + }
  29245. +
  29246. + @Test
  29247. + public void convertIntTensorShouldThrow() {
  29248. + TensorBuffer badTypeTensor = TensorBuffer.createFixedSize(new int[] {1, 4}, DataType.UINT8);
  29249. +
  29250. + Assert.assertThrows(IllegalArgumentException.class,
  29251. + ()
  29252. + -> BoundingBoxUtil.convert(badTypeTensor, new int[] {0, 1, 2, 3}, -1,
  29253. + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 300, 400));
  29254. + }
  29255. }
  29256. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java
  29257. index c41508308291a..329b5aa370744 100644
  29258. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java
  29259. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java
  29260. @@ -15,10 +15,12 @@ limitations under the License.
  29261. package org.tensorflow.lite.support.image;
  29262. import static com.google.common.truth.Truth.assertThat;
  29263. +
  29264. import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleBitmap;
  29265. import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleTensorBuffer;
  29266. import android.graphics.Bitmap;
  29267. +
  29268. import org.junit.Test;
  29269. import org.junit.runner.RunWith;
  29270. import org.junit.runners.JUnit4;
  29271. @@ -27,22 +29,21 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  29272. @RunWith(JUnit4.class)
  29273. public final class ColorSpaceTypeInstrumentedTest {
  29274. -
  29275. - @Test
  29276. - public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithUint8() {
  29277. - TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.UINT8, false);
  29278. - Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer);
  29279. -
  29280. - Bitmap expectedBitmap = createGrayscaleBitmap();
  29281. - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  29282. - }
  29283. -
  29284. - @Test
  29285. - public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithFloat() {
  29286. - TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.FLOAT32, false);
  29287. - Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer);
  29288. -
  29289. - Bitmap expectedBitmap = createGrayscaleBitmap();
  29290. - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  29291. - }
  29292. + @Test
  29293. + public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithUint8() {
  29294. + TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.UINT8, false);
  29295. + Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer);
  29296. +
  29297. + Bitmap expectedBitmap = createGrayscaleBitmap();
  29298. + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  29299. + }
  29300. +
  29301. + @Test
  29302. + public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithFloat() {
  29303. + TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.FLOAT32, false);
  29304. + Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer);
  29305. +
  29306. + Bitmap expectedBitmap = createGrayscaleBitmap();
  29307. + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  29308. + }
  29309. }
  29310. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java
  29311. index 46977fdb2bdfa..92612255269f6 100644
  29312. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java
  29313. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java
  29314. @@ -16,6 +16,7 @@ limitations under the License.
  29315. package org.tensorflow.lite.support.image;
  29316. import static com.google.common.truth.Truth.assertThat;
  29317. +
  29318. import static org.junit.Assert.assertThrows;
  29319. import static org.tensorflow.lite.support.image.TestImageCreator.createRgbBitmap;
  29320. import static org.tensorflow.lite.support.image.TestImageCreator.createRgbTensorBuffer;
  29321. @@ -23,8 +24,7 @@ import static org.tensorflow.lite.support.image.TestImageCreator.createRgbTensor
  29322. import android.graphics.Bitmap;
  29323. import android.graphics.Bitmap.Config;
  29324. import android.graphics.ImageFormat;
  29325. -import java.util.Arrays;
  29326. -import java.util.Collection;
  29327. +
  29328. import org.junit.Rule;
  29329. import org.junit.Test;
  29330. import org.junit.rules.ErrorCollector;
  29331. @@ -38,386 +38,353 @@ import org.robolectric.RobolectricTestRunner;
  29332. import org.tensorflow.lite.DataType;
  29333. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  29334. +import java.util.Arrays;
  29335. +import java.util.Collection;
  29336. +
  29337. /** Tests of {@link ImageConversions}. */
  29338. @RunWith(Suite.class)
  29339. -@SuiteClasses({
  29340. - ColorSpaceTypeTest.ValidShapeTest.class,
  29341. - ColorSpaceTypeTest.InvalidShapeTest.class,
  29342. - ColorSpaceTypeTest.BitmapConfigTest.class,
  29343. - ColorSpaceTypeTest.ImageFormatTest.class,
  29344. - ColorSpaceTypeTest.YuvImageTest.class,
  29345. - ColorSpaceTypeTest.AssertNumElementsTest.class,
  29346. - ColorSpaceTypeTest.General.class
  29347. -})
  29348. +@SuiteClasses({ColorSpaceTypeTest.ValidShapeTest.class, ColorSpaceTypeTest.InvalidShapeTest.class,
  29349. + ColorSpaceTypeTest.BitmapConfigTest.class, ColorSpaceTypeTest.ImageFormatTest.class,
  29350. + ColorSpaceTypeTest.YuvImageTest.class, ColorSpaceTypeTest.AssertNumElementsTest.class,
  29351. + ColorSpaceTypeTest.General.class})
  29352. public class ColorSpaceTypeTest {
  29353. -
  29354. - /** Parameterized tests for valid shapes. */
  29355. - @RunWith(ParameterizedRobolectricTestRunner.class)
  29356. - public static final class ValidShapeTest extends ColorSpaceTypeTest {
  29357. -
  29358. - @Parameter(0)
  29359. - public ColorSpaceType colorSpaceType;
  29360. -
  29361. - /** The shape that matches the colorSpaceType. */
  29362. - @Parameter(1)
  29363. - public int[] validShape;
  29364. -
  29365. - /** The height of validShape. */
  29366. - @Parameter(2)
  29367. - public int expectedHeight;
  29368. -
  29369. - /** The width of validShape. */
  29370. - @Parameter(3)
  29371. - public int expectedWidth;
  29372. -
  29373. - @Parameters(name = "colorSpaceType={0}; validShape={1}; height={2}; width={3}")
  29374. - public static Collection<Object[]> data() {
  29375. - return Arrays.asList(
  29376. - new Object[][] {
  29377. - {ColorSpaceType.RGB, new int[] {1, 10, 20, 3}, 10, 20},
  29378. - {ColorSpaceType.RGB, new int[] {10, 20, 3}, 10, 20},
  29379. - {ColorSpaceType.GRAYSCALE, new int[] {10, 20}, 10, 20},
  29380. - {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 1}, 10, 20},
  29381. - });
  29382. - }
  29383. -
  29384. - @Test
  29385. - public void getHeightSucceedsWithValidShape() {
  29386. - assertThat(colorSpaceType.getHeight(validShape)).isEqualTo(expectedHeight);
  29387. + /** Parameterized tests for valid shapes. */
  29388. + @RunWith(ParameterizedRobolectricTestRunner.class)
  29389. + public static final class ValidShapeTest extends ColorSpaceTypeTest {
  29390. + @Parameter(0)
  29391. + public ColorSpaceType colorSpaceType;
  29392. +
  29393. + /** The shape that matches the colorSpaceType. */
  29394. + @Parameter(1)
  29395. + public int[] validShape;
  29396. +
  29397. + /** The height of validShape. */
  29398. + @Parameter(2)
  29399. + public int expectedHeight;
  29400. +
  29401. + /** The width of validShape. */
  29402. + @Parameter(3)
  29403. + public int expectedWidth;
  29404. +
  29405. + @Parameters(name = "colorSpaceType={0}; validShape={1}; height={2}; width={3}")
  29406. + public static Collection<Object[]> data() {
  29407. + return Arrays.asList(new Object[][] {
  29408. + {ColorSpaceType.RGB, new int[] {1, 10, 20, 3}, 10, 20},
  29409. + {ColorSpaceType.RGB, new int[] {10, 20, 3}, 10, 20},
  29410. + {ColorSpaceType.GRAYSCALE, new int[] {10, 20}, 10, 20},
  29411. + {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 1}, 10, 20},
  29412. + });
  29413. + }
  29414. +
  29415. + @Test
  29416. + public void getHeightSucceedsWithValidShape() {
  29417. + assertThat(colorSpaceType.getHeight(validShape)).isEqualTo(expectedHeight);
  29418. + }
  29419. +
  29420. + @Test
  29421. + public void getWidthSucceedsWithValidShape() {
  29422. + assertThat(colorSpaceType.getWidth(validShape)).isEqualTo(expectedWidth);
  29423. + }
  29424. }
  29425. - @Test
  29426. - public void getWidthSucceedsWithValidShape() {
  29427. - assertThat(colorSpaceType.getWidth(validShape)).isEqualTo(expectedWidth);
  29428. - }
  29429. - }
  29430. -
  29431. - /** Parameterized tests for invalid shapes. */
  29432. - @RunWith(ParameterizedRobolectricTestRunner.class)
  29433. - public static final class InvalidShapeTest extends ColorSpaceTypeTest {
  29434. -
  29435. - private static final String RGB_ASSERT_SHAPE_MESSAGE =
  29436. - "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
  29437. - + " representing R, G, B in order. The provided image shape is ";
  29438. - private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE =
  29439. - "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
  29440. - + " shape is ";
  29441. -
  29442. - @Parameter(0)
  29443. - public ColorSpaceType colorSpaceType;
  29444. -
  29445. - /** The shape that does not match the colorSpaceType. */
  29446. - @Parameter(1)
  29447. - public int[] invalidShape;
  29448. -
  29449. - @Parameter(2)
  29450. - public String errorMessage;
  29451. -
  29452. - @Parameters(name = "colorSpaceType={0}; invalidShape={1}")
  29453. - public static Collection<Object[]> data() {
  29454. - return Arrays.asList(
  29455. - new Object[][] {
  29456. - {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
  29457. - {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
  29458. - {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
  29459. - {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
  29460. - {ColorSpaceType.RGB, new int[] {1, -10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
  29461. - {ColorSpaceType.RGB, new int[] {1, 10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
  29462. - {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
  29463. - {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
  29464. - {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
  29465. - {ColorSpaceType.RGB, new int[] {-10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
  29466. - {ColorSpaceType.RGB, new int[] {10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
  29467. - {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  29468. - {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  29469. - {ColorSpaceType.GRAYSCALE, new int[] {1, -10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  29470. - {ColorSpaceType.GRAYSCALE, new int[] {1, 10, -20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  29471. - {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  29472. - {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  29473. - {ColorSpaceType.GRAYSCALE, new int[] {-10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  29474. - {ColorSpaceType.GRAYSCALE, new int[] {10, -20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  29475. - });
  29476. + /** Parameterized tests for invalid shapes. */
  29477. + @RunWith(ParameterizedRobolectricTestRunner.class)
  29478. + public static final class InvalidShapeTest extends ColorSpaceTypeTest {
  29479. + private static final String RGB_ASSERT_SHAPE_MESSAGE =
  29480. + "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
  29481. + + " representing R, G, B in order. The provided image shape is ";
  29482. + private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE =
  29483. + "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
  29484. + + " shape is ";
  29485. +
  29486. + @Parameter(0)
  29487. + public ColorSpaceType colorSpaceType;
  29488. +
  29489. + /** The shape that does not match the colorSpaceType. */
  29490. + @Parameter(1)
  29491. + public int[] invalidShape;
  29492. +
  29493. + @Parameter(2)
  29494. + public String errorMessage;
  29495. +
  29496. + @Parameters(name = "colorSpaceType={0}; invalidShape={1}")
  29497. + public static Collection<Object[]> data() {
  29498. + return Arrays.asList(new Object[][] {
  29499. + {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
  29500. + {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
  29501. + {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
  29502. + {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
  29503. + {ColorSpaceType.RGB, new int[] {1, -10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
  29504. + {ColorSpaceType.RGB, new int[] {1, 10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
  29505. + {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
  29506. + {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
  29507. + {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
  29508. + {ColorSpaceType.RGB, new int[] {-10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
  29509. + {ColorSpaceType.RGB, new int[] {10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
  29510. + {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20},
  29511. + GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  29512. + {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3},
  29513. + GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  29514. + {ColorSpaceType.GRAYSCALE, new int[] {1, -10, 20},
  29515. + GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  29516. + {ColorSpaceType.GRAYSCALE, new int[] {1, 10, -20},
  29517. + GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  29518. + {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4},
  29519. + GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  29520. + {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  29521. + {ColorSpaceType.GRAYSCALE, new int[] {-10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  29522. + {ColorSpaceType.GRAYSCALE, new int[] {10, -20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  29523. + });
  29524. + }
  29525. +
  29526. + @Test
  29527. + public void assertShapeFaislsWithInvalidShape() {
  29528. + IllegalArgumentException exception = assertThrows(
  29529. + IllegalArgumentException.class, () -> colorSpaceType.assertShape(invalidShape));
  29530. + assertThat(exception).hasMessageThat().contains(
  29531. + errorMessage + Arrays.toString(invalidShape));
  29532. + }
  29533. +
  29534. + @Test
  29535. + public void getHeightFaislsWithInvalidShape() {
  29536. + IllegalArgumentException exception = assertThrows(
  29537. + IllegalArgumentException.class, () -> colorSpaceType.getHeight(invalidShape));
  29538. + assertThat(exception).hasMessageThat().contains(
  29539. + errorMessage + Arrays.toString(invalidShape));
  29540. + }
  29541. +
  29542. + @Test
  29543. + public void getWidthFaislsWithInvalidShape() {
  29544. + IllegalArgumentException exception = assertThrows(
  29545. + IllegalArgumentException.class, () -> colorSpaceType.getWidth(invalidShape));
  29546. + assertThat(exception).hasMessageThat().contains(
  29547. + errorMessage + Arrays.toString(invalidShape));
  29548. + }
  29549. }
  29550. - @Test
  29551. - public void assertShapeFaislsWithInvalidShape() {
  29552. - IllegalArgumentException exception =
  29553. - assertThrows(
  29554. - IllegalArgumentException.class, () -> colorSpaceType.assertShape(invalidShape));
  29555. - assertThat(exception).hasMessageThat().contains(errorMessage + Arrays.toString(invalidShape));
  29556. + /** Parameterized tests for Bitmap Config. */
  29557. + @RunWith(ParameterizedRobolectricTestRunner.class)
  29558. + public static final class BitmapConfigTest extends ColorSpaceTypeTest {
  29559. + @Parameter(0)
  29560. + public ColorSpaceType colorSpaceType;
  29561. +
  29562. + /** The Bitmap configuration match the colorSpaceType. */
  29563. + @Parameter(1)
  29564. + public Config config;
  29565. +
  29566. + @Parameters(name = "colorSpaceType={0}; config={1}")
  29567. + public static Collection<Object[]> data() {
  29568. + return Arrays.asList(new Object[][] {
  29569. + {ColorSpaceType.RGB, Config.ARGB_8888},
  29570. + {ColorSpaceType.GRAYSCALE, Config.ALPHA_8},
  29571. + });
  29572. + }
  29573. +
  29574. + @Test
  29575. + public void fromBitmapConfigSucceedsWithSupportedConfig() {
  29576. + assertThat(ColorSpaceType.fromBitmapConfig(config)).isEqualTo(colorSpaceType);
  29577. + }
  29578. +
  29579. + @Test
  29580. + public void toBitmapConfigSucceedsWithSupportedConfig() {
  29581. + assertThat(colorSpaceType.toBitmapConfig()).isEqualTo(config);
  29582. + }
  29583. }
  29584. - @Test
  29585. - public void getHeightFaislsWithInvalidShape() {
  29586. - IllegalArgumentException exception =
  29587. - assertThrows(
  29588. - IllegalArgumentException.class, () -> colorSpaceType.getHeight(invalidShape));
  29589. - assertThat(exception).hasMessageThat().contains(errorMessage + Arrays.toString(invalidShape));
  29590. + /** Parameterized tests for ImageFormat. */
  29591. + @RunWith(ParameterizedRobolectricTestRunner.class)
  29592. + public static final class ImageFormatTest extends ColorSpaceTypeTest {
  29593. + @Parameter(0)
  29594. + public ColorSpaceType colorSpaceType;
  29595. +
  29596. + /** The ImageFormat that matches the colorSpaceType. */
  29597. + @Parameter(1)
  29598. + public int imageFormat;
  29599. +
  29600. + @Parameters(name = "colorSpaceType={0}; imageFormat={1}")
  29601. + public static Collection<Object[]> data() {
  29602. + return Arrays.asList(new Object[][] {
  29603. + {ColorSpaceType.NV21, ImageFormat.NV21},
  29604. + {ColorSpaceType.YV12, ImageFormat.YV12},
  29605. + {ColorSpaceType.YUV_420_888, ImageFormat.YUV_420_888},
  29606. + });
  29607. + }
  29608. +
  29609. + @Test
  29610. + public void fromImageFormatSucceedsWithSupportedImageFormat() {
  29611. + assertThat(ColorSpaceType.fromImageFormat(imageFormat)).isEqualTo(colorSpaceType);
  29612. + }
  29613. }
  29614. - @Test
  29615. - public void getWidthFaislsWithInvalidShape() {
  29616. - IllegalArgumentException exception =
  29617. - assertThrows(IllegalArgumentException.class, () -> colorSpaceType.getWidth(invalidShape));
  29618. - assertThat(exception).hasMessageThat().contains(errorMessage + Arrays.toString(invalidShape));
  29619. - }
  29620. - }
  29621. -
  29622. - /** Parameterized tests for Bitmap Config. */
  29623. - @RunWith(ParameterizedRobolectricTestRunner.class)
  29624. - public static final class BitmapConfigTest extends ColorSpaceTypeTest {
  29625. -
  29626. - @Parameter(0)
  29627. - public ColorSpaceType colorSpaceType;
  29628. -
  29629. - /** The Bitmap configuration match the colorSpaceType. */
  29630. - @Parameter(1)
  29631. - public Config config;
  29632. -
  29633. - @Parameters(name = "colorSpaceType={0}; config={1}")
  29634. - public static Collection<Object[]> data() {
  29635. - return Arrays.asList(
  29636. - new Object[][] {
  29637. - {ColorSpaceType.RGB, Config.ARGB_8888},
  29638. - {ColorSpaceType.GRAYSCALE, Config.ALPHA_8},
  29639. - });
  29640. + /** Parameterized tests for YUV image formats: NV12, NV21, YV12, YV21, YUV_420_888. */
  29641. + @RunWith(ParameterizedRobolectricTestRunner.class)
  29642. + public static final class YuvImageTest extends ColorSpaceTypeTest {
  29643. + @Parameter(0)
  29644. + public ColorSpaceType colorSpaceType;
  29645. +
  29646. + @Parameters(name = "colorSpaceType={0}")
  29647. + public static Collection<Object[]> data() {
  29648. + return Arrays.asList(new Object[][] {
  29649. + {ColorSpaceType.NV12},
  29650. + {ColorSpaceType.NV21},
  29651. + {ColorSpaceType.YV12},
  29652. + {ColorSpaceType.YV21},
  29653. + {ColorSpaceType.YUV_420_888},
  29654. + });
  29655. + }
  29656. +
  29657. + @Test
  29658. + public void convertTensorBufferToBitmapShouldFail() {
  29659. + UnsupportedOperationException exception =
  29660. + assertThrows(UnsupportedOperationException.class,
  29661. + ()
  29662. + -> colorSpaceType.convertTensorBufferToBitmap(
  29663. + TensorBuffer.createDynamic(DataType.FLOAT32)));
  29664. + assertThat(exception).hasMessageThat().contains(
  29665. + "convertTensorBufferToBitmap() is unsupported for the color space type "
  29666. + + colorSpaceType.name());
  29667. + }
  29668. +
  29669. + @Test
  29670. + public void getWidthShouldFail() {
  29671. + UnsupportedOperationException exception =
  29672. + assertThrows(UnsupportedOperationException.class,
  29673. + () -> colorSpaceType.getWidth(new int[] {}));
  29674. + assertThat(exception).hasMessageThat().contains(
  29675. + "getWidth() only supports RGB and GRAYSCALE formats, but not "
  29676. + + colorSpaceType.name());
  29677. + }
  29678. +
  29679. + @Test
  29680. + public void getHeightShouldFail() {
  29681. + UnsupportedOperationException exception =
  29682. + assertThrows(UnsupportedOperationException.class,
  29683. + () -> colorSpaceType.getHeight(new int[] {}));
  29684. + assertThat(exception).hasMessageThat().contains(
  29685. + "getHeight() only supports RGB and GRAYSCALE formats, but not "
  29686. + + colorSpaceType.name());
  29687. + }
  29688. +
  29689. + @Test
  29690. + public void assertShapeShouldFail() {
  29691. + UnsupportedOperationException exception =
  29692. + assertThrows(UnsupportedOperationException.class,
  29693. + () -> colorSpaceType.assertShape(new int[] {}));
  29694. + assertThat(exception).hasMessageThat().contains(
  29695. + "assertShape() only supports RGB and GRAYSCALE formats, but not "
  29696. + + colorSpaceType.name());
  29697. + }
  29698. +
  29699. + @Test
  29700. + public void getChannelValueShouldFail() {
  29701. + UnsupportedOperationException exception = assertThrows(
  29702. + UnsupportedOperationException.class, () -> colorSpaceType.getChannelValue());
  29703. + assertThat(exception).hasMessageThat().contains(
  29704. + "getChannelValue() is unsupported for the color space type "
  29705. + + colorSpaceType.name());
  29706. + }
  29707. +
  29708. + @Test
  29709. + public void getNormalizedShapeShouldFail() {
  29710. + UnsupportedOperationException exception =
  29711. + assertThrows(UnsupportedOperationException.class,
  29712. + () -> colorSpaceType.getNormalizedShape(new int[] {}));
  29713. + assertThat(exception).hasMessageThat().contains(
  29714. + "getNormalizedShape() is unsupported for the color space type "
  29715. + + colorSpaceType.name());
  29716. + }
  29717. +
  29718. + @Test
  29719. + public void getShapeInfoMessageShouldFail() {
  29720. + UnsupportedOperationException exception =
  29721. + assertThrows(UnsupportedOperationException.class,
  29722. + () -> colorSpaceType.getShapeInfoMessage());
  29723. + assertThat(exception).hasMessageThat().contains(
  29724. + "getShapeInfoMessage() is unsupported for the color space type "
  29725. + + colorSpaceType.name());
  29726. + }
  29727. +
  29728. + @Test
  29729. + public void toBitmapConfigShouldFail() {
  29730. + UnsupportedOperationException exception = assertThrows(
  29731. + UnsupportedOperationException.class, () -> colorSpaceType.toBitmapConfig());
  29732. + assertThat(exception).hasMessageThat().contains(
  29733. + "toBitmapConfig() is unsupported for the color space type "
  29734. + + colorSpaceType.name());
  29735. + }
  29736. }
  29737. - @Test
  29738. - public void fromBitmapConfigSucceedsWithSupportedConfig() {
  29739. - assertThat(ColorSpaceType.fromBitmapConfig(config)).isEqualTo(colorSpaceType);
  29740. - }
  29741. -
  29742. - @Test
  29743. - public void toBitmapConfigSucceedsWithSupportedConfig() {
  29744. - assertThat(colorSpaceType.toBitmapConfig()).isEqualTo(config);
  29745. - }
  29746. - }
  29747. -
  29748. - /** Parameterized tests for ImageFormat. */
  29749. - @RunWith(ParameterizedRobolectricTestRunner.class)
  29750. - public static final class ImageFormatTest extends ColorSpaceTypeTest {
  29751. -
  29752. - @Parameter(0)
  29753. - public ColorSpaceType colorSpaceType;
  29754. -
  29755. - /** The ImageFormat that matches the colorSpaceType. */
  29756. - @Parameter(1)
  29757. - public int imageFormat;
  29758. -
  29759. - @Parameters(name = "colorSpaceType={0}; imageFormat={1}")
  29760. - public static Collection<Object[]> data() {
  29761. - return Arrays.asList(
  29762. - new Object[][] {
  29763. - {ColorSpaceType.NV21, ImageFormat.NV21},
  29764. - {ColorSpaceType.YV12, ImageFormat.YV12},
  29765. - {ColorSpaceType.YUV_420_888, ImageFormat.YUV_420_888},
  29766. - });
  29767. - }
  29768. -
  29769. - @Test
  29770. - public void fromImageFormatSucceedsWithSupportedImageFormat() {
  29771. - assertThat(ColorSpaceType.fromImageFormat(imageFormat)).isEqualTo(colorSpaceType);
  29772. - }
  29773. - }
  29774. -
  29775. - /** Parameterized tests for YUV image formats: NV12, NV21, YV12, YV21, YUV_420_888. */
  29776. - @RunWith(ParameterizedRobolectricTestRunner.class)
  29777. - public static final class YuvImageTest extends ColorSpaceTypeTest {
  29778. -
  29779. - @Parameter(0)
  29780. - public ColorSpaceType colorSpaceType;
  29781. -
  29782. - @Parameters(name = "colorSpaceType={0}")
  29783. - public static Collection<Object[]> data() {
  29784. - return Arrays.asList(
  29785. - new Object[][] {
  29786. - {ColorSpaceType.NV12},
  29787. - {ColorSpaceType.NV21},
  29788. - {ColorSpaceType.YV12},
  29789. - {ColorSpaceType.YV21},
  29790. - {ColorSpaceType.YUV_420_888},
  29791. - });
  29792. - }
  29793. -
  29794. - @Test
  29795. - public void convertTensorBufferToBitmapShouldFail() {
  29796. - UnsupportedOperationException exception =
  29797. - assertThrows(
  29798. - UnsupportedOperationException.class,
  29799. - () ->
  29800. - colorSpaceType.convertTensorBufferToBitmap(
  29801. - TensorBuffer.createDynamic(DataType.FLOAT32)));
  29802. - assertThat(exception)
  29803. - .hasMessageThat()
  29804. - .contains(
  29805. - "convertTensorBufferToBitmap() is unsupported for the color space type "
  29806. - + colorSpaceType.name());
  29807. - }
  29808. -
  29809. - @Test
  29810. - public void getWidthShouldFail() {
  29811. - UnsupportedOperationException exception =
  29812. - assertThrows(
  29813. - UnsupportedOperationException.class, () -> colorSpaceType.getWidth(new int[] {}));
  29814. - assertThat(exception)
  29815. - .hasMessageThat()
  29816. - .contains(
  29817. - "getWidth() only supports RGB and GRAYSCALE formats, but not "
  29818. - + colorSpaceType.name());
  29819. - }
  29820. -
  29821. - @Test
  29822. - public void getHeightShouldFail() {
  29823. - UnsupportedOperationException exception =
  29824. - assertThrows(
  29825. - UnsupportedOperationException.class, () -> colorSpaceType.getHeight(new int[] {}));
  29826. - assertThat(exception)
  29827. - .hasMessageThat()
  29828. - .contains(
  29829. - "getHeight() only supports RGB and GRAYSCALE formats, but not "
  29830. - + colorSpaceType.name());
  29831. - }
  29832. -
  29833. - @Test
  29834. - public void assertShapeShouldFail() {
  29835. - UnsupportedOperationException exception =
  29836. - assertThrows(
  29837. - UnsupportedOperationException.class, () -> colorSpaceType.assertShape(new int[] {}));
  29838. - assertThat(exception)
  29839. - .hasMessageThat()
  29840. - .contains(
  29841. - "assertShape() only supports RGB and GRAYSCALE formats, but not "
  29842. - + colorSpaceType.name());
  29843. - }
  29844. -
  29845. - @Test
  29846. - public void getChannelValueShouldFail() {
  29847. - UnsupportedOperationException exception =
  29848. - assertThrows(UnsupportedOperationException.class, () -> colorSpaceType.getChannelValue());
  29849. - assertThat(exception)
  29850. - .hasMessageThat()
  29851. - .contains(
  29852. - "getChannelValue() is unsupported for the color space type " + colorSpaceType.name());
  29853. - }
  29854. -
  29855. - @Test
  29856. - public void getNormalizedShapeShouldFail() {
  29857. - UnsupportedOperationException exception =
  29858. - assertThrows(
  29859. - UnsupportedOperationException.class,
  29860. - () -> colorSpaceType.getNormalizedShape(new int[] {}));
  29861. - assertThat(exception)
  29862. - .hasMessageThat()
  29863. - .contains(
  29864. - "getNormalizedShape() is unsupported for the color space type "
  29865. - + colorSpaceType.name());
  29866. - }
  29867. -
  29868. - @Test
  29869. - public void getShapeInfoMessageShouldFail() {
  29870. - UnsupportedOperationException exception =
  29871. - assertThrows(
  29872. - UnsupportedOperationException.class, () -> colorSpaceType.getShapeInfoMessage());
  29873. - assertThat(exception)
  29874. - .hasMessageThat()
  29875. - .contains(
  29876. - "getShapeInfoMessage() is unsupported for the color space type "
  29877. - + colorSpaceType.name());
  29878. - }
  29879. -
  29880. - @Test
  29881. - public void toBitmapConfigShouldFail() {
  29882. - UnsupportedOperationException exception =
  29883. - assertThrows(UnsupportedOperationException.class, () -> colorSpaceType.toBitmapConfig());
  29884. - assertThat(exception)
  29885. - .hasMessageThat()
  29886. - .contains(
  29887. - "toBitmapConfig() is unsupported for the color space type " + colorSpaceType.name());
  29888. - }
  29889. - }
  29890. -
  29891. - /** Parameterized tests for assertNumElements/getNumElements with all image formats. */
  29892. - @RunWith(ParameterizedRobolectricTestRunner.class)
  29893. - public static final class AssertNumElementsTest extends ColorSpaceTypeTest {
  29894. - private static final int HEIGHT = 2;
  29895. - private static final int WIDTH = 3;
  29896. - private static final int LESS_NUM_ELEMENTS = 5; // less than expected
  29897. - private static final int MORE_NUM_ELEMENTS = 20; // more than expected. OK.
  29898. - @Rule public ErrorCollector errorCollector = new ErrorCollector();
  29899. -
  29900. - @Parameter(0)
  29901. - public ColorSpaceType colorSpaceType;
  29902. -
  29903. - @Parameter(1)
  29904. - public int expectedNumElements;
  29905. -
  29906. - @Parameters(name = "colorSpaceType={0};expectedNumElements={1}")
  29907. - public static Collection<Object[]> data() {
  29908. - return Arrays.asList(
  29909. - new Object[][] {
  29910. - {ColorSpaceType.RGB, 18},
  29911. - {ColorSpaceType.GRAYSCALE, 6},
  29912. - {ColorSpaceType.NV12, 10},
  29913. - {ColorSpaceType.NV21, 10},
  29914. - {ColorSpaceType.YV12, 10},
  29915. - {ColorSpaceType.YV21, 10},
  29916. - });
  29917. - }
  29918. -
  29919. - @Test
  29920. - public void getNumElementsShouldSucceedWithExpectedNumElements() {
  29921. - assertThat(colorSpaceType.getNumElements(HEIGHT, WIDTH)).isEqualTo(expectedNumElements);
  29922. - }
  29923. -
  29924. - @Test
  29925. - public void assertNumElementsShouldSucceedWithMoreNumElements() {
  29926. - errorCollector.checkSucceeds(
  29927. - () -> {
  29928. - colorSpaceType.assertNumElements(MORE_NUM_ELEMENTS, HEIGHT, WIDTH);
  29929. - return null;
  29930. - });
  29931. - }
  29932. -
  29933. - @Test
  29934. - public void assertNumElementsShouldFailWithLessNumElements() {
  29935. - IllegalArgumentException exception =
  29936. - assertThrows(
  29937. - IllegalArgumentException.class,
  29938. - () -> colorSpaceType.assertNumElements(LESS_NUM_ELEMENTS, HEIGHT, WIDTH));
  29939. - assertThat(exception)
  29940. - .hasMessageThat()
  29941. - .contains(
  29942. - String.format(
  29943. - "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
  29944. - + " expected number of elements should be at least %d.",
  29945. - LESS_NUM_ELEMENTS, colorSpaceType.name(), HEIGHT, WIDTH, expectedNumElements));
  29946. - }
  29947. - }
  29948. -
  29949. - /** General tests of ColorSpaceTypeTest. */
  29950. - @RunWith(RobolectricTestRunner.class)
  29951. - public static final class General extends ColorSpaceTypeTest {
  29952. -
  29953. - @Test
  29954. - public void convertTensorBufferToBitmapShouldSuccessWithRGB() {
  29955. - TensorBuffer buffer = createRgbTensorBuffer(DataType.UINT8, false);
  29956. - Bitmap bitmap = ColorSpaceType.RGB.convertTensorBufferToBitmap(buffer);
  29957. -
  29958. - Bitmap expectedBitmap = createRgbBitmap();
  29959. - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  29960. + /** Parameterized tests for assertNumElements/getNumElements with all image formats. */
  29961. + @RunWith(ParameterizedRobolectricTestRunner.class)
  29962. + public static final class AssertNumElementsTest extends ColorSpaceTypeTest {
  29963. + private static final int HEIGHT = 2;
  29964. + private static final int WIDTH = 3;
  29965. + private static final int LESS_NUM_ELEMENTS = 5; // less than expected
  29966. + private static final int MORE_NUM_ELEMENTS = 20; // more than expected. OK.
  29967. + @Rule
  29968. + public ErrorCollector errorCollector = new ErrorCollector();
  29969. +
  29970. + @Parameter(0)
  29971. + public ColorSpaceType colorSpaceType;
  29972. +
  29973. + @Parameter(1)
  29974. + public int expectedNumElements;
  29975. +
  29976. + @Parameters(name = "colorSpaceType={0};expectedNumElements={1}")
  29977. + public static Collection<Object[]> data() {
  29978. + return Arrays.asList(new Object[][] {
  29979. + {ColorSpaceType.RGB, 18},
  29980. + {ColorSpaceType.GRAYSCALE, 6},
  29981. + {ColorSpaceType.NV12, 10},
  29982. + {ColorSpaceType.NV21, 10},
  29983. + {ColorSpaceType.YV12, 10},
  29984. + {ColorSpaceType.YV21, 10},
  29985. + });
  29986. + }
  29987. +
  29988. + @Test
  29989. + public void getNumElementsShouldSucceedWithExpectedNumElements() {
  29990. + assertThat(colorSpaceType.getNumElements(HEIGHT, WIDTH)).isEqualTo(expectedNumElements);
  29991. + }
  29992. +
  29993. + @Test
  29994. + public void assertNumElementsShouldSucceedWithMoreNumElements() {
  29995. + errorCollector.checkSucceeds(() -> {
  29996. + colorSpaceType.assertNumElements(MORE_NUM_ELEMENTS, HEIGHT, WIDTH);
  29997. + return null;
  29998. + });
  29999. + }
  30000. +
  30001. + @Test
  30002. + public void assertNumElementsShouldFailWithLessNumElements() {
  30003. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  30004. + () -> colorSpaceType.assertNumElements(LESS_NUM_ELEMENTS, HEIGHT, WIDTH));
  30005. + assertThat(exception).hasMessageThat().contains(String.format(
  30006. + "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
  30007. + + " expected number of elements should be at least %d.",
  30008. + LESS_NUM_ELEMENTS, colorSpaceType.name(), HEIGHT, WIDTH, expectedNumElements));
  30009. + }
  30010. }
  30011. - @Test
  30012. - public void fromBitmapConfigFailsWithUnsupportedConfig() {
  30013. - Config unsupportedConfig = Config.ARGB_4444;
  30014. - IllegalArgumentException exception =
  30015. - assertThrows(
  30016. - IllegalArgumentException.class,
  30017. - () -> ColorSpaceType.fromBitmapConfig(unsupportedConfig));
  30018. - assertThat(exception)
  30019. - .hasMessageThat()
  30020. - .contains("Bitmap configuration: " + unsupportedConfig + ", is not supported yet.");
  30021. + /** General tests of ColorSpaceTypeTest. */
  30022. + @RunWith(RobolectricTestRunner.class)
  30023. + public static final class General extends ColorSpaceTypeTest {
  30024. + @Test
  30025. + public void convertTensorBufferToBitmapShouldSuccessWithRGB() {
  30026. + TensorBuffer buffer = createRgbTensorBuffer(DataType.UINT8, false);
  30027. + Bitmap bitmap = ColorSpaceType.RGB.convertTensorBufferToBitmap(buffer);
  30028. +
  30029. + Bitmap expectedBitmap = createRgbBitmap();
  30030. + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  30031. + }
  30032. +
  30033. + @Test
  30034. + public void fromBitmapConfigFailsWithUnsupportedConfig() {
  30035. + Config unsupportedConfig = Config.ARGB_4444;
  30036. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  30037. + () -> ColorSpaceType.fromBitmapConfig(unsupportedConfig));
  30038. + assertThat(exception).hasMessageThat().contains(
  30039. + "Bitmap configuration: " + unsupportedConfig + ", is not supported yet.");
  30040. + }
  30041. }
  30042. - }
  30043. }
  30044. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java
  30045. index 1a4d367bf0fe1..49efc4273911c 100644
  30046. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java
  30047. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java
  30048. @@ -21,7 +21,9 @@ import static android.graphics.Color.BLUE;
  30049. import static android.graphics.Color.GREEN;
  30050. import static android.graphics.Color.RED;
  30051. import static android.graphics.Color.WHITE;
  30052. +
  30053. import static com.google.common.truth.Truth.assertThat;
  30054. +
  30055. import static org.junit.Assert.assertThrows;
  30056. import static org.tensorflow.lite.support.image.ImageConversions.convertGrayscaleTensorBufferToBitmap;
  30057. @@ -30,10 +32,10 @@ import android.content.res.AssetManager;
  30058. import android.graphics.Bitmap;
  30059. import android.graphics.BitmapFactory;
  30060. import android.util.Log;
  30061. +
  30062. import androidx.test.core.app.ApplicationProvider;
  30063. import androidx.test.ext.junit.runners.AndroidJUnit4;
  30064. -import java.io.IOException;
  30065. -import java.util.Arrays;
  30066. +
  30067. import org.junit.Assert;
  30068. import org.junit.Before;
  30069. import org.junit.Test;
  30070. @@ -43,192 +45,190 @@ import org.junit.runners.Suite.SuiteClasses;
  30071. import org.tensorflow.lite.DataType;
  30072. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  30073. +import java.io.IOException;
  30074. +import java.util.Arrays;
  30075. +
  30076. /** Instrumented unit test for {@link ImageConversions}. */
  30077. @RunWith(Suite.class)
  30078. -@SuiteClasses({
  30079. - ImageConversionsInstrumentedTest.TensorBufferToBitmap.class,
  30080. - ImageConversionsInstrumentedTest.BitmapToTensorBuffer.class
  30081. -})
  30082. +@SuiteClasses({ImageConversionsInstrumentedTest.TensorBufferToBitmap.class,
  30083. + ImageConversionsInstrumentedTest.BitmapToTensorBuffer.class})
  30084. public class ImageConversionsInstrumentedTest {
  30085. + /** Tests for the TensorBuffer data type and normalized form. */
  30086. + // Note that parameterized test with android_library_instrumentation_tests is currently not
  30087. + // supported internally.
  30088. + @RunWith(AndroidJUnit4.class)
  30089. + public static final class TensorBufferToBitmap extends ImageConversionsInstrumentedTest {
  30090. + @Test
  30091. + public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatNormalized() {
  30092. + DataType dataType = DataType.FLOAT32;
  30093. + boolean isNormalized = true;
  30094. +
  30095. + TensorBuffer buffer =
  30096. + TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
  30097. + Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
  30098. +
  30099. + Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
  30100. + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  30101. + }
  30102. - /** Tests for the TensorBuffer data type and normalized form. */
  30103. - // Note that parameterized test with android_library_instrumentation_tests is currently not
  30104. - // supported internally.
  30105. - @RunWith(AndroidJUnit4.class)
  30106. - public static final class TensorBufferToBitmap extends ImageConversionsInstrumentedTest {
  30107. -
  30108. - @Test
  30109. - public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatNormalized() {
  30110. - DataType dataType = DataType.FLOAT32;
  30111. - boolean isNormalized = true;
  30112. + @Test
  30113. + public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatUnnormalized() {
  30114. + DataType dataType = DataType.FLOAT32;
  30115. + boolean isNormalized = false;
  30116. - TensorBuffer buffer = TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
  30117. - Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
  30118. + TensorBuffer buffer =
  30119. + TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
  30120. + Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
  30121. - Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
  30122. - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  30123. - }
  30124. -
  30125. - @Test
  30126. - public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatUnnormalized() {
  30127. - DataType dataType = DataType.FLOAT32;
  30128. - boolean isNormalized = false;
  30129. + Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
  30130. + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  30131. + }
  30132. - TensorBuffer buffer = TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
  30133. - Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
  30134. + @Test
  30135. + public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Normalized() {
  30136. + DataType dataType = DataType.UINT8;
  30137. + boolean isNormalized = true;
  30138. - Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
  30139. - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  30140. - }
  30141. + TensorBuffer buffer =
  30142. + TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
  30143. + Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
  30144. - @Test
  30145. - public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Normalized() {
  30146. - DataType dataType = DataType.UINT8;
  30147. - boolean isNormalized = true;
  30148. -
  30149. - TensorBuffer buffer = TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
  30150. - Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
  30151. + Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
  30152. + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  30153. + }
  30154. - Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
  30155. - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  30156. - }
  30157. + @Test
  30158. + public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Unnormalized() {
  30159. + DataType dataType = DataType.UINT8;
  30160. + boolean isNormalized = false;
  30161. - @Test
  30162. - public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Unnormalized() {
  30163. - DataType dataType = DataType.UINT8;
  30164. - boolean isNormalized = false;
  30165. + TensorBuffer buffer =
  30166. + TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
  30167. + Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
  30168. - TensorBuffer buffer = TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
  30169. - Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
  30170. + Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
  30171. + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  30172. + }
  30173. - Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
  30174. - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  30175. - }
  30176. + @Test
  30177. + public void
  30178. + convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithFloat() {
  30179. + DataType dataType = DataType.FLOAT32;
  30180. + TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType);
  30181. +
  30182. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  30183. + () -> convertGrayscaleTensorBufferToBitmap(buffer));
  30184. + assertThat(exception).hasMessageThat().contains(
  30185. + "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
  30186. + + " shape is " + Arrays.toString(buffer.getShape()));
  30187. + }
  30188. - @Test
  30189. - public void convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithFloat() {
  30190. - DataType dataType = DataType.FLOAT32;
  30191. - TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType);
  30192. -
  30193. - IllegalArgumentException exception =
  30194. - assertThrows(
  30195. - IllegalArgumentException.class, () -> convertGrayscaleTensorBufferToBitmap(buffer));
  30196. - assertThat(exception)
  30197. - .hasMessageThat()
  30198. - .contains(
  30199. - "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
  30200. - + " shape is "
  30201. - + Arrays.toString(buffer.getShape()));
  30202. + @Test
  30203. + public void
  30204. + convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithUint8() {
  30205. + DataType dataType = DataType.UINT8;
  30206. + TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType);
  30207. +
  30208. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  30209. + () -> convertGrayscaleTensorBufferToBitmap(buffer));
  30210. + assertThat(exception).hasMessageThat().contains(
  30211. + "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
  30212. + + " shape is " + Arrays.toString(buffer.getShape()));
  30213. + }
  30214. }
  30215. - @Test
  30216. - public void convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithUint8() {
  30217. - DataType dataType = DataType.UINT8;
  30218. - TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType);
  30219. -
  30220. - IllegalArgumentException exception =
  30221. - assertThrows(
  30222. - IllegalArgumentException.class, () -> convertGrayscaleTensorBufferToBitmap(buffer));
  30223. - assertThat(exception)
  30224. - .hasMessageThat()
  30225. - .contains(
  30226. - "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
  30227. - + " shape is "
  30228. - + Arrays.toString(buffer.getShape()));
  30229. - }
  30230. - }
  30231. -
  30232. - /** BitmapToTensorBuffer tests of ImageConversionsInstrumentedTest. */
  30233. - @RunWith(AndroidJUnit4.class)
  30234. - public static final class BitmapToTensorBuffer extends ImageConversionsInstrumentedTest {
  30235. -
  30236. - private Bitmap greyGrid;
  30237. - private Bitmap colorGrid;
  30238. - private TensorBuffer buffer;
  30239. -
  30240. - static final String GREY_GRID_PATH = "grey_grid.png";
  30241. - static final String COLOR_GRID_PATH = "color_grid.png";
  30242. -
  30243. - @Before
  30244. - public void loadAssets() {
  30245. - Context context = ApplicationProvider.getApplicationContext();
  30246. - AssetManager assetManager = context.getAssets();
  30247. - try {
  30248. - greyGrid = BitmapFactory.decodeStream(assetManager.open(GREY_GRID_PATH));
  30249. - colorGrid = BitmapFactory.decodeStream(assetManager.open(COLOR_GRID_PATH));
  30250. - } catch (IOException e) {
  30251. - Log.e("Test", "Cannot load asset files");
  30252. - }
  30253. - Assert.assertEquals(ARGB_8888, greyGrid.getConfig());
  30254. - Assert.assertEquals(ARGB_8888, colorGrid.getConfig());
  30255. - buffer = TensorBuffer.createDynamic(DataType.UINT8);
  30256. - }
  30257. + /** BitmapToTensorBuffer tests of ImageConversionsInstrumentedTest. */
  30258. + @RunWith(AndroidJUnit4.class)
  30259. + public static final class BitmapToTensorBuffer extends ImageConversionsInstrumentedTest {
  30260. + private Bitmap greyGrid;
  30261. + private Bitmap colorGrid;
  30262. + private TensorBuffer buffer;
  30263. +
  30264. + static final String GREY_GRID_PATH = "grey_grid.png";
  30265. + static final String COLOR_GRID_PATH = "color_grid.png";
  30266. +
  30267. + @Before
  30268. + public void loadAssets() {
  30269. + Context context = ApplicationProvider.getApplicationContext();
  30270. + AssetManager assetManager = context.getAssets();
  30271. + try {
  30272. + greyGrid = BitmapFactory.decodeStream(assetManager.open(GREY_GRID_PATH));
  30273. + colorGrid = BitmapFactory.decodeStream(assetManager.open(COLOR_GRID_PATH));
  30274. + } catch (IOException e) {
  30275. + Log.e("Test", "Cannot load asset files");
  30276. + }
  30277. + Assert.assertEquals(ARGB_8888, greyGrid.getConfig());
  30278. + Assert.assertEquals(ARGB_8888, colorGrid.getConfig());
  30279. + buffer = TensorBuffer.createDynamic(DataType.UINT8);
  30280. + }
  30281. - @Test
  30282. - public void testBitmapDimensionLayout() {
  30283. - // This test is not only for proving the correctness of bitmap -> TensorBuffer conversion, but
  30284. - // also for us to better understand how Android Bitmap is storing pixels - height first or
  30285. - // width first.
  30286. - // We use a black image which has a white corner to understand what happens. By setting up the
  30287. - // correct loop to pass the test, we can reveal the order of pixels returned from `getPixels`.
  30288. - // The result shows that Android stores bitmap in an h-first manner. The returned array of
  30289. - // `getPixels` is like [ 1st row, 2nd row, ... ] which is the same with TFLite.
  30290. - Assert.assertEquals(100, greyGrid.getWidth());
  30291. - Assert.assertEquals(100, greyGrid.getHeight());
  30292. - Assert.assertEquals(BLACK, greyGrid.getPixel(25, 25)); // left top
  30293. - Assert.assertEquals(BLACK, greyGrid.getPixel(75, 25)); // right top
  30294. - Assert.assertEquals(WHITE, greyGrid.getPixel(25, 75)); // left bottom
  30295. - Assert.assertEquals(BLACK, greyGrid.getPixel(75, 75)); // right bottom
  30296. -
  30297. - ImageConversions.convertBitmapToTensorBuffer(greyGrid, buffer);
  30298. - Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape());
  30299. - Assert.assertEquals(DataType.UINT8, buffer.getDataType());
  30300. -
  30301. - int[] pixels = buffer.getIntArray();
  30302. - int index = 0;
  30303. - for (int h = 0; h < 100; h++) {
  30304. - for (int w = 0; w < 100; w++) {
  30305. - int expected = (w < 50 && h >= 50) ? 255 : 0;
  30306. - Assert.assertEquals(expected, pixels[index++]);
  30307. - Assert.assertEquals(expected, pixels[index++]);
  30308. - Assert.assertEquals(expected, pixels[index++]);
  30309. + @Test
  30310. + public void testBitmapDimensionLayout() {
  30311. + // This test is not only for proving the correctness of bitmap -> TensorBuffer
  30312. + // conversion, but also for us to better understand how Android Bitmap is storing pixels
  30313. + // - height first or width first. We use a black image which has a white corner to
  30314. + // understand what happens. By setting up the correct loop to pass the test, we can
  30315. + // reveal the order of pixels returned from `getPixels`. The result shows that Android
  30316. + // stores bitmap in an h-first manner. The returned array of `getPixels` is like [ 1st
  30317. + // row, 2nd row, ... ] which is the same with TFLite.
  30318. + Assert.assertEquals(100, greyGrid.getWidth());
  30319. + Assert.assertEquals(100, greyGrid.getHeight());
  30320. + Assert.assertEquals(BLACK, greyGrid.getPixel(25, 25)); // left top
  30321. + Assert.assertEquals(BLACK, greyGrid.getPixel(75, 25)); // right top
  30322. + Assert.assertEquals(WHITE, greyGrid.getPixel(25, 75)); // left bottom
  30323. + Assert.assertEquals(BLACK, greyGrid.getPixel(75, 75)); // right bottom
  30324. +
  30325. + ImageConversions.convertBitmapToTensorBuffer(greyGrid, buffer);
  30326. + Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape());
  30327. + Assert.assertEquals(DataType.UINT8, buffer.getDataType());
  30328. +
  30329. + int[] pixels = buffer.getIntArray();
  30330. + int index = 0;
  30331. + for (int h = 0; h < 100; h++) {
  30332. + for (int w = 0; w < 100; w++) {
  30333. + int expected = (w < 50 && h >= 50) ? 255 : 0;
  30334. + Assert.assertEquals(expected, pixels[index++]);
  30335. + Assert.assertEquals(expected, pixels[index++]);
  30336. + Assert.assertEquals(expected, pixels[index++]);
  30337. + }
  30338. + }
  30339. }
  30340. - }
  30341. - }
  30342. - @Test
  30343. - public void testBitmapARGB8888ChannelLayout() {
  30344. - // This test is not only for proving the correctness of bitmap -> TensorBuffer conversion, but
  30345. - // also for us to better understand how Android Bitmap is storing pixels - RGB channel or
  30346. - // other possible ordering.
  30347. - // We use an colored grid image to understand what happens. It's a simple grid image with 4
  30348. - // grid in different colors. Passed through our Bitmap -> TensorBuffer conversion which simply
  30349. - // unpack channels from an integer returned from `getPixel`, its channel sequence could be
  30350. - // revealed directly.
  30351. - // The result shows that Android Bitmap has no magic when loading channels. If loading from
  30352. - // PNG images, channel order still remains R-G-B.
  30353. - Assert.assertEquals(100, colorGrid.getWidth());
  30354. - Assert.assertEquals(100, colorGrid.getHeight());
  30355. - Assert.assertEquals(BLUE, colorGrid.getPixel(25, 25)); // left top
  30356. - Assert.assertEquals(BLACK, colorGrid.getPixel(75, 25)); // right top
  30357. - Assert.assertEquals(GREEN, colorGrid.getPixel(25, 75)); // left bottom
  30358. - Assert.assertEquals(RED, colorGrid.getPixel(75, 75)); // right bottom
  30359. -
  30360. - ImageConversions.convertBitmapToTensorBuffer(colorGrid, buffer);
  30361. - Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape());
  30362. - Assert.assertEquals(DataType.UINT8, buffer.getDataType());
  30363. -
  30364. - int[] pixels = buffer.getIntArray();
  30365. - Assert.assertArrayEquals(new int[] {0, 0, 255}, getChannels(pixels, 25, 25)); // left top
  30366. - Assert.assertArrayEquals(new int[] {0, 0, 0}, getChannels(pixels, 25, 75)); // right top
  30367. - Assert.assertArrayEquals(new int[] {0, 255, 0}, getChannels(pixels, 75, 25)); // left bottom
  30368. - Assert.assertArrayEquals(new int[] {255, 0, 0}, getChannels(pixels, 75, 75)); // right bottom
  30369. - }
  30370. + @Test
  30371. + public void testBitmapARGB8888ChannelLayout() {
  30372. + // This test is not only for proving the correctness of bitmap -> TensorBuffer
  30373. + // conversion, but also for us to better understand how Android Bitmap is storing pixels
  30374. + // - RGB channel or other possible ordering. We use an colored grid image to understand
  30375. + // what happens. It's a simple grid image with 4 grid in different colors. Passed
  30376. + // through our Bitmap -> TensorBuffer conversion which simply unpack channels from an
  30377. + // integer returned from `getPixel`, its channel sequence could be revealed directly.
  30378. + // The result shows that Android Bitmap has no magic when loading channels. If loading
  30379. + // from PNG images, channel order still remains R-G-B.
  30380. + Assert.assertEquals(100, colorGrid.getWidth());
  30381. + Assert.assertEquals(100, colorGrid.getHeight());
  30382. + Assert.assertEquals(BLUE, colorGrid.getPixel(25, 25)); // left top
  30383. + Assert.assertEquals(BLACK, colorGrid.getPixel(75, 25)); // right top
  30384. + Assert.assertEquals(GREEN, colorGrid.getPixel(25, 75)); // left bottom
  30385. + Assert.assertEquals(RED, colorGrid.getPixel(75, 75)); // right bottom
  30386. +
  30387. + ImageConversions.convertBitmapToTensorBuffer(colorGrid, buffer);
  30388. + Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape());
  30389. + Assert.assertEquals(DataType.UINT8, buffer.getDataType());
  30390. +
  30391. + int[] pixels = buffer.getIntArray();
  30392. + Assert.assertArrayEquals(
  30393. + new int[] {0, 0, 255}, getChannels(pixels, 25, 25)); // left top
  30394. + Assert.assertArrayEquals(new int[] {0, 0, 0}, getChannels(pixels, 25, 75)); // right top
  30395. + Assert.assertArrayEquals(
  30396. + new int[] {0, 255, 0}, getChannels(pixels, 75, 25)); // left bottom
  30397. + Assert.assertArrayEquals(
  30398. + new int[] {255, 0, 0}, getChannels(pixels, 75, 75)); // right bottom
  30399. + }
  30400. - /** Helper function only for {@link #testBitmapARGB8888ChannelLayout()}. */
  30401. - private static int[] getChannels(int[] pixels, int h, int w) {
  30402. - int id = (h * 100 + w) * 3;
  30403. - return new int[] {pixels[id++], pixels[id++], pixels[id]};
  30404. + /** Helper function only for {@link #testBitmapARGB8888ChannelLayout()}. */
  30405. + private static int[] getChannels(int[] pixels, int h, int w) {
  30406. + int id = (h * 100 + w) * 3;
  30407. + return new int[] {pixels[id++], pixels[id++], pixels[id]};
  30408. + }
  30409. }
  30410. - }
  30411. }
  30412. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java
  30413. index b3300872c2357..c91db9d184f63 100644
  30414. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java
  30415. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java
  30416. @@ -16,13 +16,13 @@ limitations under the License.
  30417. package org.tensorflow.lite.support.image;
  30418. import static com.google.common.truth.Truth.assertThat;
  30419. +
  30420. import static org.junit.Assert.assertThrows;
  30421. import static org.tensorflow.lite.support.image.ImageConversions.convertBitmapToTensorBuffer;
  30422. import static org.tensorflow.lite.support.image.ImageConversions.convertRgbTensorBufferToBitmap;
  30423. import android.graphics.Bitmap;
  30424. -import java.util.Arrays;
  30425. -import java.util.Collection;
  30426. +
  30427. import org.junit.Assert;
  30428. import org.junit.Test;
  30429. import org.junit.runner.RunWith;
  30430. @@ -35,93 +35,93 @@ import org.robolectric.RobolectricTestRunner;
  30431. import org.tensorflow.lite.DataType;
  30432. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  30433. +import java.util.Arrays;
  30434. +import java.util.Collection;
  30435. +
  30436. /** Tests of {@link ImageConversions}. */
  30437. @RunWith(Suite.class)
  30438. @SuiteClasses({ImageConversionsTest.TensorBufferToBitmap.class, ImageConversionsTest.General.class})
  30439. public class ImageConversionsTest {
  30440. -
  30441. - /** Parameterized tests for the TensorBuffer data type and normalized form. */
  30442. - @RunWith(ParameterizedRobolectricTestRunner.class)
  30443. - public static final class TensorBufferToBitmap extends ImageConversionsTest {
  30444. -
  30445. - /** The data type that used to create the TensorBuffer. */
  30446. - @Parameter(0)
  30447. - public DataType dataType;
  30448. -
  30449. - /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */
  30450. - @Parameter(1)
  30451. - public boolean isNormalized;
  30452. -
  30453. - @Parameters(name = "dataType={0}; isNormalized={1}")
  30454. - public static Collection<Object[]> data() {
  30455. - return Arrays.asList(
  30456. - new Object[][] {
  30457. - {DataType.FLOAT32, true}, {DataType.UINT8, true},
  30458. - {DataType.FLOAT32, false}, {DataType.UINT8, false},
  30459. - });
  30460. - }
  30461. -
  30462. - @Test
  30463. - public void convertRgbTensorBufferToBitmapShouldSuccess() {
  30464. - TensorBuffer buffer = TestImageCreator.createRgbTensorBuffer(dataType, isNormalized);
  30465. - Bitmap bitmap = convertRgbTensorBufferToBitmap(buffer);
  30466. -
  30467. - Bitmap expectedBitmap = TestImageCreator.createRgbBitmap();
  30468. - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  30469. - }
  30470. -
  30471. - @Test
  30472. - public void convertRgbTensorBufferToBitmapShouldRejectBufferWithInvalidShape() {
  30473. - TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10, 3}, dataType);
  30474. -
  30475. - IllegalArgumentException exception =
  30476. - assertThrows(
  30477. - IllegalArgumentException.class, () -> convertRgbTensorBufferToBitmap(buffer));
  30478. - assertThat(exception)
  30479. - .hasMessageThat()
  30480. - .contains(
  30481. - "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
  30482. - + " representing R, G, B in order. The provided image shape is "
  30483. - + Arrays.toString(buffer.getShape()));
  30484. - }
  30485. - }
  30486. -
  30487. - /** General tests of ImageConversionsTest. */
  30488. - @RunWith(RobolectricTestRunner.class)
  30489. - public static final class General extends ImageConversionsTest {
  30490. -
  30491. - private static final Bitmap rgbBitmap = TestImageCreator.createRgbBitmap();
  30492. - private static final TensorBuffer rgbTensorBuffer =
  30493. - TestImageCreator.createRgbTensorBuffer(DataType.UINT8, false);
  30494. -
  30495. - @Test
  30496. - public void convertBitmapToTensorBufferShouldSuccess() {
  30497. - TensorBuffer intBuffer = TensorBuffer.createFixedSize(new int[] {10, 10, 3}, DataType.UINT8);
  30498. - convertBitmapToTensorBuffer(rgbBitmap, intBuffer);
  30499. - assertThat(areEqualIntTensorBuffer(intBuffer, rgbTensorBuffer)).isTrue();
  30500. - }
  30501. -
  30502. - @Test
  30503. - public void convertBitmapToTensorBufferShouldThrowShapeNotExactlySame() {
  30504. - TensorBuffer intBuffer = TensorBuffer.createFixedSize(new int[] {5, 20, 3}, DataType.UINT8);
  30505. - Assert.assertThrows(
  30506. - IllegalArgumentException.class, () -> convertBitmapToTensorBuffer(rgbBitmap, intBuffer));
  30507. + /** Parameterized tests for the TensorBuffer data type and normalized form. */
  30508. + @RunWith(ParameterizedRobolectricTestRunner.class)
  30509. + public static final class TensorBufferToBitmap extends ImageConversionsTest {
  30510. + /** The data type that used to create the TensorBuffer. */
  30511. + @Parameter(0)
  30512. + public DataType dataType;
  30513. +
  30514. + /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */
  30515. + @Parameter(1)
  30516. + public boolean isNormalized;
  30517. +
  30518. + @Parameters(name = "dataType={0}; isNormalized={1}")
  30519. + public static Collection<Object[]> data() {
  30520. + return Arrays.asList(new Object[][] {
  30521. + {DataType.FLOAT32, true},
  30522. + {DataType.UINT8, true},
  30523. + {DataType.FLOAT32, false},
  30524. + {DataType.UINT8, false},
  30525. + });
  30526. + }
  30527. +
  30528. + @Test
  30529. + public void convertRgbTensorBufferToBitmapShouldSuccess() {
  30530. + TensorBuffer buffer = TestImageCreator.createRgbTensorBuffer(dataType, isNormalized);
  30531. + Bitmap bitmap = convertRgbTensorBufferToBitmap(buffer);
  30532. +
  30533. + Bitmap expectedBitmap = TestImageCreator.createRgbBitmap();
  30534. + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  30535. + }
  30536. +
  30537. + @Test
  30538. + public void convertRgbTensorBufferToBitmapShouldRejectBufferWithInvalidShape() {
  30539. + TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10, 3}, dataType);
  30540. +
  30541. + IllegalArgumentException exception = assertThrows(
  30542. + IllegalArgumentException.class, () -> convertRgbTensorBufferToBitmap(buffer));
  30543. + assertThat(exception).hasMessageThat().contains(
  30544. + "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
  30545. + + " representing R, G, B in order. The provided image shape is "
  30546. + + Arrays.toString(buffer.getShape()));
  30547. + }
  30548. }
  30549. - @Test
  30550. - public void convertBitmapToTensorBufferShouldCastIntToFloatIfNeeded() {
  30551. - TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  30552. - convertBitmapToTensorBuffer(rgbBitmap, floatBuffer);
  30553. - assertThat(areEqualIntTensorBuffer(floatBuffer, rgbTensorBuffer)).isTrue();
  30554. + /** General tests of ImageConversionsTest. */
  30555. + @RunWith(RobolectricTestRunner.class)
  30556. + public static final class General extends ImageConversionsTest {
  30557. + private static final Bitmap rgbBitmap = TestImageCreator.createRgbBitmap();
  30558. + private static final TensorBuffer rgbTensorBuffer =
  30559. + TestImageCreator.createRgbTensorBuffer(DataType.UINT8, false);
  30560. +
  30561. + @Test
  30562. + public void convertBitmapToTensorBufferShouldSuccess() {
  30563. + TensorBuffer intBuffer =
  30564. + TensorBuffer.createFixedSize(new int[] {10, 10, 3}, DataType.UINT8);
  30565. + convertBitmapToTensorBuffer(rgbBitmap, intBuffer);
  30566. + assertThat(areEqualIntTensorBuffer(intBuffer, rgbTensorBuffer)).isTrue();
  30567. + }
  30568. +
  30569. + @Test
  30570. + public void convertBitmapToTensorBufferShouldThrowShapeNotExactlySame() {
  30571. + TensorBuffer intBuffer =
  30572. + TensorBuffer.createFixedSize(new int[] {5, 20, 3}, DataType.UINT8);
  30573. + Assert.assertThrows(IllegalArgumentException.class,
  30574. + () -> convertBitmapToTensorBuffer(rgbBitmap, intBuffer));
  30575. + }
  30576. +
  30577. + @Test
  30578. + public void convertBitmapToTensorBufferShouldCastIntToFloatIfNeeded() {
  30579. + TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  30580. + convertBitmapToTensorBuffer(rgbBitmap, floatBuffer);
  30581. + assertThat(areEqualIntTensorBuffer(floatBuffer, rgbTensorBuffer)).isTrue();
  30582. + }
  30583. }
  30584. - }
  30585. - private static boolean areEqualIntTensorBuffer(TensorBuffer tb1, TensorBuffer tb2) {
  30586. - if (!Arrays.equals(tb1.getShape(), tb2.getShape())) {
  30587. - return false;
  30588. + private static boolean areEqualIntTensorBuffer(TensorBuffer tb1, TensorBuffer tb2) {
  30589. + if (!Arrays.equals(tb1.getShape(), tb2.getShape())) {
  30590. + return false;
  30591. + }
  30592. + int[] arr1 = tb1.getIntArray();
  30593. + int[] arr2 = tb2.getIntArray();
  30594. + return Arrays.equals(arr1, arr2);
  30595. }
  30596. - int[] arr1 = tb1.getIntArray();
  30597. - int[] arr2 = tb2.getIntArray();
  30598. - return Arrays.equals(arr1, arr2);
  30599. - }
  30600. }
  30601. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java
  30602. index 8ac27fdb07ad1..e9cbfc1dc50bd 100644
  30603. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java
  30604. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java
  30605. @@ -16,10 +16,13 @@ limitations under the License.
  30606. package org.tensorflow.lite.support.image;
  30607. import static com.google.common.truth.Truth.assertThat;
  30608. +
  30609. import static org.junit.Assert.assertThrows;
  30610. import android.graphics.Bitmap;
  30611. +
  30612. import androidx.test.ext.junit.runners.AndroidJUnit4;
  30613. +
  30614. import org.junit.Before;
  30615. import org.junit.Test;
  30616. import org.junit.runner.RunWith;
  30617. @@ -30,120 +33,114 @@ import org.tensorflow.lite.support.image.ops.Rot90Op;
  30618. /** Instrumented unit test for {@link ImageProcessor}. */
  30619. @RunWith(AndroidJUnit4.class)
  30620. public final class ImageProcessorInstrumentedTest {
  30621. + private Bitmap exampleBitmap;
  30622. + private TensorImage input;
  30623. + private ImageProcessor processor;
  30624. +
  30625. + private static final int EXAMPLE_WIDTH = 10;
  30626. + private static final int EXAMPLE_HEIGHT = 15;
  30627. +
  30628. + @Before
  30629. + public void setUp() {
  30630. + // The default number of rotation is once.
  30631. + processor = new ImageProcessor.Builder().add(new Rot90Op()).build();
  30632. + exampleBitmap = createExampleBitmap();
  30633. + input = new TensorImage(DataType.UINT8);
  30634. + input.load(exampleBitmap);
  30635. + }
  30636. +
  30637. + @Test
  30638. + public void updateNumberOfRotations_rotateTwice() {
  30639. + int numberOfRotations = 2;
  30640. +
  30641. + processor.updateNumberOfRotations(numberOfRotations);
  30642. + TensorImage output = processor.process(input);
  30643. +
  30644. + Bitmap outputBitmap = output.getBitmap();
  30645. + assertExampleBitmapWithTwoRotations(outputBitmap);
  30646. + }
  30647. +
  30648. + @Test
  30649. + public void updateNumberOfRotationsWithOpIndex_rotateTwiceAndOpIndex0() {
  30650. + int numberOfRotations = 2;
  30651. + int occurrence = 0;
  30652. +
  30653. + processor.updateNumberOfRotations(numberOfRotations, occurrence);
  30654. + TensorImage output = processor.process(input);
  30655. +
  30656. + Bitmap outputBitmap = output.getBitmap();
  30657. + assertExampleBitmapWithTwoRotations(outputBitmap);
  30658. + }
  30659. +
  30660. + @Test
  30661. + public void updateNumberOfRotationsWithOpIndex_negativeOpIndex() {
  30662. + int numberOfRotations = 2;
  30663. + int negativeOpIndex = -1;
  30664. +
  30665. + IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
  30666. + () -> processor.updateNumberOfRotations(numberOfRotations, negativeOpIndex));
  30667. + assertThat(exception).hasMessageThat().isEqualTo("occurrence (-1) must not be negative");
  30668. + }
  30669. +
  30670. + @Test
  30671. + public void updateNumberOfRotationsWithOpIndex_occurrenceEqualToTheNumberOfRot90Op() {
  30672. + int numberOfRotations = 2;
  30673. + int occurrence = 1;
  30674. +
  30675. + IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
  30676. + () -> processor.updateNumberOfRotations(numberOfRotations, occurrence));
  30677. + assertThat(exception).hasMessageThat().isEqualTo(
  30678. + "occurrence (1) must be less than size (1)");
  30679. + }
  30680. +
  30681. + @Test
  30682. + public void updateNumberOfRotationsWithOpIndex_noRot90OpIsAddedToImageProcessor() {
  30683. + int numberOfRotations = 2;
  30684. + int occurrence = 1;
  30685. + // Add an op other than Rot90Op into ImageProcessor.
  30686. + ImageProcessor processor =
  30687. + new ImageProcessor.Builder().add(new ResizeWithCropOrPadOp(5, 5)).build();
  30688. +
  30689. + IllegalStateException exception = assertThrows(IllegalStateException.class,
  30690. + () -> processor.updateNumberOfRotations(numberOfRotations, occurrence));
  30691. + assertThat(exception).hasMessageThat().isEqualTo(
  30692. + "The Rot90Op has not been added to the ImageProcessor.");
  30693. + }
  30694. +
  30695. + @Test
  30696. + public void updateNumberOfRotationsWithOpIndex_twoRot90Ops() {
  30697. + // The overall effect of the two rotations is equivalent to rotating for twice.
  30698. + int numberOfRotations0 = 5;
  30699. + int numberOfRotations1 = 1;
  30700. +
  30701. + // Add two Rot90Ops into ImageProcessor.
  30702. + ImageProcessor processor =
  30703. + new ImageProcessor.Builder().add(new Rot90Op()).add(new Rot90Op()).build();
  30704. + processor.updateNumberOfRotations(numberOfRotations0, /*occurrence=*/0);
  30705. + processor.updateNumberOfRotations(numberOfRotations1, /*occurrence=*/1);
  30706. +
  30707. + TensorImage output = processor.process(input);
  30708. + Bitmap outputBitmap = output.getBitmap();
  30709. + assertExampleBitmapWithTwoRotations(outputBitmap);
  30710. + }
  30711. - private Bitmap exampleBitmap;
  30712. - private TensorImage input;
  30713. - private ImageProcessor processor;
  30714. -
  30715. - private static final int EXAMPLE_WIDTH = 10;
  30716. - private static final int EXAMPLE_HEIGHT = 15;
  30717. -
  30718. - @Before
  30719. - public void setUp() {
  30720. - // The default number of rotation is once.
  30721. - processor = new ImageProcessor.Builder().add(new Rot90Op()).build();
  30722. - exampleBitmap = createExampleBitmap();
  30723. - input = new TensorImage(DataType.UINT8);
  30724. - input.load(exampleBitmap);
  30725. - }
  30726. -
  30727. - @Test
  30728. - public void updateNumberOfRotations_rotateTwice() {
  30729. - int numberOfRotations = 2;
  30730. -
  30731. - processor.updateNumberOfRotations(numberOfRotations);
  30732. - TensorImage output = processor.process(input);
  30733. -
  30734. - Bitmap outputBitmap = output.getBitmap();
  30735. - assertExampleBitmapWithTwoRotations(outputBitmap);
  30736. - }
  30737. -
  30738. - @Test
  30739. - public void updateNumberOfRotationsWithOpIndex_rotateTwiceAndOpIndex0() {
  30740. - int numberOfRotations = 2;
  30741. - int occurrence = 0;
  30742. -
  30743. - processor.updateNumberOfRotations(numberOfRotations, occurrence);
  30744. - TensorImage output = processor.process(input);
  30745. -
  30746. - Bitmap outputBitmap = output.getBitmap();
  30747. - assertExampleBitmapWithTwoRotations(outputBitmap);
  30748. - }
  30749. -
  30750. - @Test
  30751. - public void updateNumberOfRotationsWithOpIndex_negativeOpIndex() {
  30752. - int numberOfRotations = 2;
  30753. - int negativeOpIndex = -1;
  30754. -
  30755. - IndexOutOfBoundsException exception =
  30756. - assertThrows(
  30757. - IndexOutOfBoundsException.class,
  30758. - () -> processor.updateNumberOfRotations(numberOfRotations, negativeOpIndex));
  30759. - assertThat(exception).hasMessageThat().isEqualTo("occurrence (-1) must not be negative");
  30760. - }
  30761. -
  30762. - @Test
  30763. - public void updateNumberOfRotationsWithOpIndex_occurrenceEqualToTheNumberOfRot90Op() {
  30764. - int numberOfRotations = 2;
  30765. - int occurrence = 1;
  30766. -
  30767. - IndexOutOfBoundsException exception =
  30768. - assertThrows(
  30769. - IndexOutOfBoundsException.class,
  30770. - () -> processor.updateNumberOfRotations(numberOfRotations, occurrence));
  30771. - assertThat(exception).hasMessageThat().isEqualTo("occurrence (1) must be less than size (1)");
  30772. - }
  30773. -
  30774. - @Test
  30775. - public void updateNumberOfRotationsWithOpIndex_noRot90OpIsAddedToImageProcessor() {
  30776. - int numberOfRotations = 2;
  30777. - int occurrence = 1;
  30778. - // Add an op other than Rot90Op into ImageProcessor.
  30779. - ImageProcessor processor =
  30780. - new ImageProcessor.Builder().add(new ResizeWithCropOrPadOp(5, 5)).build();
  30781. -
  30782. - IllegalStateException exception =
  30783. - assertThrows(
  30784. - IllegalStateException.class,
  30785. - () -> processor.updateNumberOfRotations(numberOfRotations, occurrence));
  30786. - assertThat(exception)
  30787. - .hasMessageThat()
  30788. - .isEqualTo("The Rot90Op has not been added to the ImageProcessor.");
  30789. - }
  30790. -
  30791. - @Test
  30792. - public void updateNumberOfRotationsWithOpIndex_twoRot90Ops() {
  30793. - // The overall effect of the two rotations is equivalent to rotating for twice.
  30794. - int numberOfRotations0 = 5;
  30795. - int numberOfRotations1 = 1;
  30796. -
  30797. - // Add two Rot90Ops into ImageProcessor.
  30798. - ImageProcessor processor =
  30799. - new ImageProcessor.Builder().add(new Rot90Op()).add(new Rot90Op()).build();
  30800. - processor.updateNumberOfRotations(numberOfRotations0, /*occurrence=*/ 0);
  30801. - processor.updateNumberOfRotations(numberOfRotations1, /*occurrence=*/ 1);
  30802. -
  30803. - TensorImage output = processor.process(input);
  30804. - Bitmap outputBitmap = output.getBitmap();
  30805. - assertExampleBitmapWithTwoRotations(outputBitmap);
  30806. - }
  30807. -
  30808. - private void assertExampleBitmapWithTwoRotations(Bitmap bitmapRotated) {
  30809. - assertThat(bitmapRotated.getWidth()).isEqualTo(EXAMPLE_WIDTH);
  30810. - assertThat(bitmapRotated.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
  30811. - for (int i = 0; i < exampleBitmap.getWidth(); i++) {
  30812. - for (int j = 0; j < exampleBitmap.getHeight(); j++) {
  30813. - assertThat(exampleBitmap.getPixel(i, j))
  30814. - .isEqualTo(bitmapRotated.getPixel(EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j));
  30815. - }
  30816. + private void assertExampleBitmapWithTwoRotations(Bitmap bitmapRotated) {
  30817. + assertThat(bitmapRotated.getWidth()).isEqualTo(EXAMPLE_WIDTH);
  30818. + assertThat(bitmapRotated.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
  30819. + for (int i = 0; i < exampleBitmap.getWidth(); i++) {
  30820. + for (int j = 0; j < exampleBitmap.getHeight(); j++) {
  30821. + assertThat(exampleBitmap.getPixel(i, j))
  30822. + .isEqualTo(bitmapRotated.getPixel(
  30823. + EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j));
  30824. + }
  30825. + }
  30826. }
  30827. - }
  30828. - private static Bitmap createExampleBitmap() {
  30829. - int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
  30830. - for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
  30831. - colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
  30832. + private static Bitmap createExampleBitmap() {
  30833. + int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
  30834. + for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
  30835. + colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
  30836. + }
  30837. + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
  30838. }
  30839. - return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
  30840. - }
  30841. }
  30842. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java
  30843. index a655f4a506900..a93ba5465125c 100644
  30844. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java
  30845. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java
  30846. @@ -16,10 +16,12 @@ limitations under the License.
  30847. package org.tensorflow.lite.support.image;
  30848. import static com.google.common.truth.Truth.assertThat;
  30849. +
  30850. import static org.junit.Assert.assertThrows;
  30851. import android.graphics.Bitmap;
  30852. import android.graphics.RectF;
  30853. +
  30854. import org.junit.Test;
  30855. import org.junit.runner.RunWith;
  30856. import org.robolectric.RobolectricTestRunner;
  30857. @@ -34,115 +36,112 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  30858. /** Tests for {@link ImageProcessor}. */
  30859. @RunWith(RobolectricTestRunner.class)
  30860. public final class ImageProcessorTest {
  30861. + private static final int EXAMPLE_WIDTH = 10;
  30862. + private static final int EXAMPLE_HEIGHT = 15;
  30863. + private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH;
  30864. + private static final int EXAMPLE_NUM_CHANNELS = 3;
  30865. + private static final float MEAN = 127.5f;
  30866. + private static final float STDDEV = 127.5f;
  30867. +
  30868. + @Test
  30869. + public void testBuild() {
  30870. + ImageProcessor processor =
  30871. + new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
  30872. + assertThat(processor).isNotNull();
  30873. + }
  30874. - private static final int EXAMPLE_WIDTH = 10;
  30875. - private static final int EXAMPLE_HEIGHT = 15;
  30876. - private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH;
  30877. - private static final int EXAMPLE_NUM_CHANNELS = 3;
  30878. - private static final float MEAN = 127.5f;
  30879. - private static final float STDDEV = 127.5f;
  30880. -
  30881. - @Test
  30882. - public void testBuild() {
  30883. - ImageProcessor processor =
  30884. - new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
  30885. - assertThat(processor).isNotNull();
  30886. - }
  30887. -
  30888. - @Test
  30889. - public void testNormalize() {
  30890. - TensorImage input = new TensorImage(DataType.FLOAT32);
  30891. - input.load(createExampleBitmap());
  30892. - ImageProcessor processor =
  30893. - new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
  30894. - TensorImage output = processor.process(input);
  30895. -
  30896. - float[] pixels = output.getTensorBuffer().getFloatArray();
  30897. - assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS);
  30898. - for (float p : pixels) {
  30899. - assertThat(p).isAtLeast(-1);
  30900. - assertThat(p).isAtMost(1);
  30901. + @Test
  30902. + public void testNormalize() {
  30903. + TensorImage input = new TensorImage(DataType.FLOAT32);
  30904. + input.load(createExampleBitmap());
  30905. + ImageProcessor processor =
  30906. + new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
  30907. + TensorImage output = processor.process(input);
  30908. +
  30909. + float[] pixels = output.getTensorBuffer().getFloatArray();
  30910. + assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS);
  30911. + for (float p : pixels) {
  30912. + assertThat(p).isAtLeast(-1);
  30913. + assertThat(p).isAtMost(1);
  30914. + }
  30915. }
  30916. - }
  30917. -
  30918. - @Test
  30919. - public void testMultipleNormalize() {
  30920. - TensorImage input = new TensorImage(DataType.FLOAT32);
  30921. - input.load(createExampleBitmap());
  30922. - ImageProcessor processor =
  30923. - new ImageProcessor.Builder()
  30924. - .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1]
  30925. - .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1]
  30926. - .build();
  30927. - TensorImage output = processor.process(input);
  30928. -
  30929. - float[] pixels = output.getTensorBuffer().getFloatArray();
  30930. - assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS);
  30931. - for (float p : pixels) {
  30932. - assertThat(p).isAtLeast(0);
  30933. - assertThat(p).isAtMost(1);
  30934. +
  30935. + @Test
  30936. + public void testMultipleNormalize() {
  30937. + TensorImage input = new TensorImage(DataType.FLOAT32);
  30938. + input.load(createExampleBitmap());
  30939. + ImageProcessor processor =
  30940. + new ImageProcessor.Builder()
  30941. + .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1]
  30942. + .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1]
  30943. + .build();
  30944. + TensorImage output = processor.process(input);
  30945. +
  30946. + float[] pixels = output.getTensorBuffer().getFloatArray();
  30947. + assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS);
  30948. + for (float p : pixels) {
  30949. + assertThat(p).isAtLeast(0);
  30950. + assertThat(p).isAtMost(1);
  30951. + }
  30952. }
  30953. - }
  30954. -
  30955. - @Test
  30956. - public void inverseTransformRectCorrectly() {
  30957. - ImageProcessor processor =
  30958. - new ImageProcessor.Builder()
  30959. - .add(new ResizeOp(200, 300, ResizeMethod.BILINEAR))
  30960. - .add(new ResizeWithCropOrPadOp(100, 200))
  30961. - .add(new Rot90Op(1))
  30962. - .add(new NormalizeOp(127, 128))
  30963. - .build();
  30964. - RectF transformed = new RectF(0, 50, 100, 150);
  30965. - RectF original = processor.inverseTransform(transformed, 400, 600);
  30966. - assertThat(original.top).isEqualTo(100);
  30967. - assertThat(original.left).isEqualTo(200);
  30968. - assertThat(original.right).isEqualTo(400);
  30969. - assertThat(original.bottom).isEqualTo(300);
  30970. - }
  30971. -
  30972. - @Test
  30973. - public void resizeShouldFailWithNonRgbImages() {
  30974. - int[] data = new int[] {1, 2, 3};
  30975. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
  30976. - tensorBuffer.loadArray(data, new int[] {1, 3});
  30977. - TensorImage image = new TensorImage();
  30978. - image.load(tensorBuffer, ColorSpaceType.GRAYSCALE);
  30979. -
  30980. - ImageProcessor processor =
  30981. - new ImageProcessor.Builder().add(new ResizeOp(200, 300, ResizeMethod.BILINEAR)).build();
  30982. -
  30983. - IllegalArgumentException exception =
  30984. - assertThrows(IllegalArgumentException.class, () -> processor.process(image));
  30985. - assertThat(exception)
  30986. - .hasMessageThat()
  30987. - .contains(
  30988. - "Only RGB images are supported in ResizeOp, but not "
  30989. +
  30990. + @Test
  30991. + public void inverseTransformRectCorrectly() {
  30992. + ImageProcessor processor = new ImageProcessor.Builder()
  30993. + .add(new ResizeOp(200, 300, ResizeMethod.BILINEAR))
  30994. + .add(new ResizeWithCropOrPadOp(100, 200))
  30995. + .add(new Rot90Op(1))
  30996. + .add(new NormalizeOp(127, 128))
  30997. + .build();
  30998. + RectF transformed = new RectF(0, 50, 100, 150);
  30999. + RectF original = processor.inverseTransform(transformed, 400, 600);
  31000. + assertThat(original.top).isEqualTo(100);
  31001. + assertThat(original.left).isEqualTo(200);
  31002. + assertThat(original.right).isEqualTo(400);
  31003. + assertThat(original.bottom).isEqualTo(300);
  31004. + }
  31005. +
  31006. + @Test
  31007. + public void resizeShouldFailWithNonRgbImages() {
  31008. + int[] data = new int[] {1, 2, 3};
  31009. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
  31010. + tensorBuffer.loadArray(data, new int[] {1, 3});
  31011. + TensorImage image = new TensorImage();
  31012. + image.load(tensorBuffer, ColorSpaceType.GRAYSCALE);
  31013. +
  31014. + ImageProcessor processor = new ImageProcessor.Builder()
  31015. + .add(new ResizeOp(200, 300, ResizeMethod.BILINEAR))
  31016. + .build();
  31017. +
  31018. + IllegalArgumentException exception =
  31019. + assertThrows(IllegalArgumentException.class, () -> processor.process(image));
  31020. + assertThat(exception).hasMessageThat().contains(
  31021. + "Only RGB images are supported in ResizeOp, but not "
  31022. + image.getColorSpaceType().name());
  31023. - }
  31024. -
  31025. - @Test
  31026. - public void normalizeShouldSuccessWithNonRgbImages() {
  31027. - int[] data = new int[] {1, 2, 3};
  31028. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
  31029. - tensorBuffer.loadArray(data, new int[] {1, 3});
  31030. - TensorImage image = new TensorImage();
  31031. - image.load(tensorBuffer, ColorSpaceType.GRAYSCALE);
  31032. -
  31033. - ImageProcessor processor =
  31034. - new ImageProcessor.Builder().add(new NormalizeOp(0.5f, 1f)).build();
  31035. - TensorImage output = processor.process(image);
  31036. -
  31037. - float[] pixels = output.getTensorBuffer().getFloatArray();
  31038. - assertThat(pixels).isEqualTo(new float[]{0.5f, 1.5f, 2.5f});
  31039. - }
  31040. -
  31041. - private static Bitmap createExampleBitmap() {
  31042. - int[] colors = new int[EXAMPLE_NUM_PIXELS];
  31043. - for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) {
  31044. - colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
  31045. }
  31046. - return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
  31047. - }
  31048. + @Test
  31049. + public void normalizeShouldSuccessWithNonRgbImages() {
  31050. + int[] data = new int[] {1, 2, 3};
  31051. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
  31052. + tensorBuffer.loadArray(data, new int[] {1, 3});
  31053. + TensorImage image = new TensorImage();
  31054. + image.load(tensorBuffer, ColorSpaceType.GRAYSCALE);
  31055. +
  31056. + ImageProcessor processor =
  31057. + new ImageProcessor.Builder().add(new NormalizeOp(0.5f, 1f)).build();
  31058. + TensorImage output = processor.process(image);
  31059. +
  31060. + float[] pixels = output.getTensorBuffer().getFloatArray();
  31061. + assertThat(pixels).isEqualTo(new float[] {0.5f, 1.5f, 2.5f});
  31062. + }
  31063. +
  31064. + private static Bitmap createExampleBitmap() {
  31065. + int[] colors = new int[EXAMPLE_NUM_PIXELS];
  31066. + for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) {
  31067. + colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
  31068. + }
  31069. +
  31070. + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
  31071. + }
  31072. }
  31073. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java
  31074. index 7e61aa8d3ce58..e8caefcab8a04 100644
  31075. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java
  31076. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java
  31077. @@ -16,20 +16,19 @@ limitations under the License.
  31078. package org.tensorflow.lite.support.image;
  31079. import static com.google.common.truth.Truth.assertThat;
  31080. +
  31081. import static org.junit.Assert.assertThrows;
  31082. import static org.mockito.Mockito.when;
  31083. import android.graphics.Bitmap;
  31084. import android.media.Image;
  31085. +
  31086. import com.google.android.odml.image.BitmapMlImageBuilder;
  31087. import com.google.android.odml.image.ByteBufferMlImageBuilder;
  31088. import com.google.android.odml.image.MediaMlImageBuilder;
  31089. import com.google.android.odml.image.MlImage;
  31090. import com.google.android.odml.image.MlImage.ImageFormat;
  31091. -import java.io.IOException;
  31092. -import java.nio.ByteBuffer;
  31093. -import java.util.Arrays;
  31094. -import java.util.Collection;
  31095. +
  31096. import org.junit.Before;
  31097. import org.junit.Test;
  31098. import org.junit.runner.RunWith;
  31099. @@ -42,139 +41,141 @@ import org.robolectric.ParameterizedRobolectricTestRunner.Parameter;
  31100. import org.robolectric.ParameterizedRobolectricTestRunner.Parameters;
  31101. import org.robolectric.RobolectricTestRunner;
  31102. +import java.io.IOException;
  31103. +import java.nio.ByteBuffer;
  31104. +import java.util.Arrays;
  31105. +import java.util.Collection;
  31106. +
  31107. /** Test for {@link MlImageAdapter}. */
  31108. @RunWith(Suite.class)
  31109. @SuiteClasses({
  31110. - MlImageAdapterTest.CreateTensorImageFromSupportedByteBufferMlImage.class,
  31111. - MlImageAdapterTest.CreateTensorImageFromUnsupportedByteBufferMlImage.class,
  31112. - MlImageAdapterTest.General.class,
  31113. + MlImageAdapterTest.CreateTensorImageFromSupportedByteBufferMlImage.class,
  31114. + MlImageAdapterTest.CreateTensorImageFromUnsupportedByteBufferMlImage.class,
  31115. + MlImageAdapterTest.General.class,
  31116. })
  31117. public class MlImageAdapterTest {
  31118. -
  31119. - @RunWith(ParameterizedRobolectricTestRunner.class)
  31120. - public static final class CreateTensorImageFromSupportedByteBufferMlImage
  31121. - extends MlImageAdapterTest {
  31122. -
  31123. - @Parameter(0)
  31124. - @ImageFormat
  31125. - public int imageFormat;
  31126. -
  31127. - @Parameter(1)
  31128. - public ColorSpaceType colorSpaceType;
  31129. -
  31130. - @Parameters(name = "imageFormat={0}")
  31131. - public static Collection<Object[]> data() {
  31132. - return Arrays.asList(
  31133. - new Object[][] {
  31134. - {MlImage.IMAGE_FORMAT_RGB, ColorSpaceType.RGB},
  31135. - {MlImage.IMAGE_FORMAT_ALPHA, ColorSpaceType.GRAYSCALE},
  31136. - {MlImage.IMAGE_FORMAT_NV21, ColorSpaceType.NV21},
  31137. - {MlImage.IMAGE_FORMAT_NV12, ColorSpaceType.NV12},
  31138. - {MlImage.IMAGE_FORMAT_YV12, ColorSpaceType.YV12},
  31139. - {MlImage.IMAGE_FORMAT_YV21, ColorSpaceType.YV21},
  31140. - });
  31141. - }
  31142. -
  31143. - @Test
  31144. - public void createTensorImageFrom_supportedByteBufferMlImage_succeeds() throws IOException {
  31145. - ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer();
  31146. - buffer.rewind();
  31147. - MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build();
  31148. -
  31149. - TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
  31150. -
  31151. - assertThat(tensorImage.getWidth()).isEqualTo(1);
  31152. - assertThat(tensorImage.getHeight()).isEqualTo(2);
  31153. - assertThat(tensorImage.getColorSpaceType()).isEqualTo(colorSpaceType);
  31154. - assertThat(tensorImage.getBuffer().position()).isEqualTo(0);
  31155. - assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(buffer);
  31156. - }
  31157. - }
  31158. -
  31159. - @RunWith(ParameterizedRobolectricTestRunner.class)
  31160. - public static final class CreateTensorImageFromUnsupportedByteBufferMlImage
  31161. - extends MlImageAdapterTest {
  31162. - @Parameter(0)
  31163. - @ImageFormat
  31164. - public int imageFormat;
  31165. -
  31166. - @Parameters(name = "imageFormat={0}")
  31167. - public static Collection<Object[]> data() {
  31168. - return Arrays.asList(
  31169. - new Object[][] {
  31170. - {MlImage.IMAGE_FORMAT_RGBA},
  31171. - {MlImage.IMAGE_FORMAT_JPEG},
  31172. - {MlImage.IMAGE_FORMAT_YUV_420_888},
  31173. - {MlImage.IMAGE_FORMAT_UNKNOWN},
  31174. - });
  31175. - }
  31176. -
  31177. - @Test
  31178. - public void createTensorImageFrom_unsupportedByteBufferMlImage_throws() throws IOException {
  31179. - ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer();
  31180. - buffer.rewind();
  31181. - MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build();
  31182. -
  31183. - assertThrows(
  31184. - IllegalArgumentException.class, () -> MlImageAdapter.createTensorImageFrom(image));
  31185. - }
  31186. - }
  31187. -
  31188. - @RunWith(RobolectricTestRunner.class)
  31189. - public static final class General extends MlImageAdapterTest {
  31190. -
  31191. - @Mock Image mediaImageMock;
  31192. -
  31193. - @Before
  31194. - public void setUp() {
  31195. - MockitoAnnotations.openMocks(this);
  31196. - }
  31197. -
  31198. - @Test
  31199. - public void createTensorImageFrom_bitmapMlImage_succeeds() throws IOException {
  31200. - Bitmap bitmap =
  31201. - Bitmap.createBitmap(new int[] {0xff000100, 0xff000001}, 1, 2, Bitmap.Config.ARGB_8888);
  31202. - MlImage image = new BitmapMlImageBuilder(bitmap).build();
  31203. - ByteBuffer expectedBuffer = ByteBuffer.allocateDirect(6);
  31204. - for (byte b : new byte[] {0, 1, 0, 0, 0, 1}) {
  31205. - expectedBuffer.put(b);
  31206. - }
  31207. - expectedBuffer.rewind();
  31208. -
  31209. - TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
  31210. -
  31211. - assertThat(tensorImage.getWidth()).isEqualTo(1);
  31212. - assertThat(tensorImage.getHeight()).isEqualTo(2);
  31213. - assertThat(tensorImage.getBuffer().position()).isEqualTo(0);
  31214. - assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(expectedBuffer);
  31215. - }
  31216. -
  31217. - @Test
  31218. - public void createTensorImageFrom_yuv420888MediaImageMlImage_succeeds() throws IOException {
  31219. - setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_420_888, 1, 2);
  31220. - MlImage image = new MediaMlImageBuilder(mediaImageMock).build();
  31221. -
  31222. - TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
  31223. -
  31224. - assertThat(tensorImage.getWidth()).isEqualTo(1);
  31225. - assertThat(tensorImage.getHeight()).isEqualTo(2);
  31226. - assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.YUV_420_888);
  31227. + @RunWith(ParameterizedRobolectricTestRunner.class)
  31228. + public static final class CreateTensorImageFromSupportedByteBufferMlImage
  31229. + extends MlImageAdapterTest {
  31230. + @Parameter(0)
  31231. + @ImageFormat
  31232. + public int imageFormat;
  31233. +
  31234. + @Parameter(1)
  31235. + public ColorSpaceType colorSpaceType;
  31236. +
  31237. + @Parameters(name = "imageFormat={0}")
  31238. + public static Collection<Object[]> data() {
  31239. + return Arrays.asList(new Object[][] {
  31240. + {MlImage.IMAGE_FORMAT_RGB, ColorSpaceType.RGB},
  31241. + {MlImage.IMAGE_FORMAT_ALPHA, ColorSpaceType.GRAYSCALE},
  31242. + {MlImage.IMAGE_FORMAT_NV21, ColorSpaceType.NV21},
  31243. + {MlImage.IMAGE_FORMAT_NV12, ColorSpaceType.NV12},
  31244. + {MlImage.IMAGE_FORMAT_YV12, ColorSpaceType.YV12},
  31245. + {MlImage.IMAGE_FORMAT_YV21, ColorSpaceType.YV21},
  31246. + });
  31247. + }
  31248. +
  31249. + @Test
  31250. + public void createTensorImageFrom_supportedByteBufferMlImage_succeeds() throws IOException {
  31251. + ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer();
  31252. + buffer.rewind();
  31253. + MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build();
  31254. +
  31255. + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
  31256. +
  31257. + assertThat(tensorImage.getWidth()).isEqualTo(1);
  31258. + assertThat(tensorImage.getHeight()).isEqualTo(2);
  31259. + assertThat(tensorImage.getColorSpaceType()).isEqualTo(colorSpaceType);
  31260. + assertThat(tensorImage.getBuffer().position()).isEqualTo(0);
  31261. + assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(buffer);
  31262. + }
  31263. }
  31264. - @Test
  31265. - public void createTensorImageFrom_nonYuv420888MediaImageMlImage_throws() throws IOException {
  31266. - setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_422_888, 1, 2);
  31267. - MlImage image = new MediaMlImageBuilder(mediaImageMock).build();
  31268. -
  31269. - assertThrows(
  31270. - IllegalArgumentException.class, () -> MlImageAdapter.createTensorImageFrom(image));
  31271. + @RunWith(ParameterizedRobolectricTestRunner.class)
  31272. + public static final class CreateTensorImageFromUnsupportedByteBufferMlImage
  31273. + extends MlImageAdapterTest {
  31274. + @Parameter(0)
  31275. + @ImageFormat
  31276. + public int imageFormat;
  31277. +
  31278. + @Parameters(name = "imageFormat={0}")
  31279. + public static Collection<Object[]> data() {
  31280. + return Arrays.asList(new Object[][] {
  31281. + {MlImage.IMAGE_FORMAT_RGBA},
  31282. + {MlImage.IMAGE_FORMAT_JPEG},
  31283. + {MlImage.IMAGE_FORMAT_YUV_420_888},
  31284. + {MlImage.IMAGE_FORMAT_UNKNOWN},
  31285. + });
  31286. + }
  31287. +
  31288. + @Test
  31289. + public void createTensorImageFrom_unsupportedByteBufferMlImage_throws() throws IOException {
  31290. + ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer();
  31291. + buffer.rewind();
  31292. + MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build();
  31293. +
  31294. + assertThrows(IllegalArgumentException.class,
  31295. + () -> MlImageAdapter.createTensorImageFrom(image));
  31296. + }
  31297. }
  31298. - private static void setUpMediaImageMock(
  31299. - Image mediaImageMock, int imageFormat, int width, int height) {
  31300. - when(mediaImageMock.getFormat()).thenReturn(imageFormat);
  31301. - when(mediaImageMock.getWidth()).thenReturn(width);
  31302. - when(mediaImageMock.getHeight()).thenReturn(height);
  31303. + @RunWith(RobolectricTestRunner.class)
  31304. + public static final class General extends MlImageAdapterTest {
  31305. + @Mock
  31306. + Image mediaImageMock;
  31307. +
  31308. + @Before
  31309. + public void setUp() {
  31310. + MockitoAnnotations.openMocks(this);
  31311. + }
  31312. +
  31313. + @Test
  31314. + public void createTensorImageFrom_bitmapMlImage_succeeds() throws IOException {
  31315. + Bitmap bitmap = Bitmap.createBitmap(
  31316. + new int[] {0xff000100, 0xff000001}, 1, 2, Bitmap.Config.ARGB_8888);
  31317. + MlImage image = new BitmapMlImageBuilder(bitmap).build();
  31318. + ByteBuffer expectedBuffer = ByteBuffer.allocateDirect(6);
  31319. + for (byte b : new byte[] {0, 1, 0, 0, 0, 1}) {
  31320. + expectedBuffer.put(b);
  31321. + }
  31322. + expectedBuffer.rewind();
  31323. +
  31324. + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
  31325. +
  31326. + assertThat(tensorImage.getWidth()).isEqualTo(1);
  31327. + assertThat(tensorImage.getHeight()).isEqualTo(2);
  31328. + assertThat(tensorImage.getBuffer().position()).isEqualTo(0);
  31329. + assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(expectedBuffer);
  31330. + }
  31331. +
  31332. + @Test
  31333. + public void createTensorImageFrom_yuv420888MediaImageMlImage_succeeds() throws IOException {
  31334. + setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_420_888, 1, 2);
  31335. + MlImage image = new MediaMlImageBuilder(mediaImageMock).build();
  31336. +
  31337. + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
  31338. +
  31339. + assertThat(tensorImage.getWidth()).isEqualTo(1);
  31340. + assertThat(tensorImage.getHeight()).isEqualTo(2);
  31341. + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.YUV_420_888);
  31342. + }
  31343. +
  31344. + @Test
  31345. + public void createTensorImageFrom_nonYuv420888MediaImageMlImage_throws()
  31346. + throws IOException {
  31347. + setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_422_888, 1, 2);
  31348. + MlImage image = new MediaMlImageBuilder(mediaImageMock).build();
  31349. +
  31350. + assertThrows(IllegalArgumentException.class,
  31351. + () -> MlImageAdapter.createTensorImageFrom(image));
  31352. + }
  31353. +
  31354. + private static void setUpMediaImageMock(
  31355. + Image mediaImageMock, int imageFormat, int width, int height) {
  31356. + when(mediaImageMock.getFormat()).thenReturn(imageFormat);
  31357. + when(mediaImageMock.getWidth()).thenReturn(width);
  31358. + when(mediaImageMock.getHeight()).thenReturn(height);
  31359. + }
  31360. }
  31361. - }
  31362. }
  31363. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java
  31364. index ca5f7dc7551be..83b54d0a8db78 100644
  31365. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java
  31366. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java
  31367. @@ -15,6 +15,7 @@ limitations under the License.
  31368. package org.tensorflow.lite.support.image;
  31369. import static com.google.common.truth.Truth.assertThat;
  31370. +
  31371. import static org.tensorflow.lite.DataType.FLOAT32;
  31372. import static org.tensorflow.lite.DataType.UINT8;
  31373. import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleBitmap;
  31374. @@ -23,6 +24,7 @@ import static org.tensorflow.lite.support.image.TestImageCreator.createRgbBitmap
  31375. import static org.tensorflow.lite.support.image.TestImageCreator.createRgbTensorBuffer;
  31376. import android.graphics.Bitmap;
  31377. +
  31378. import org.junit.Test;
  31379. import org.junit.runner.RunWith;
  31380. import org.junit.runners.JUnit4;
  31381. @@ -31,110 +33,110 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  31382. @RunWith(JUnit4.class)
  31383. public final class TensorImageInstrumentedTest {
  31384. + /**
  31385. + * Difference between the pair of float and uint8 values. It is used to test the data
  31386. + * conversion.
  31387. + */
  31388. + private static final float DELTA = 0.1f;
  31389. +
  31390. + // Note that parameterized test with android_library_instrumentation_tests is currently not
  31391. + // supported in internally.
  31392. + @Test
  31393. + public void loadAndGetBitmapSucceedsWithFloatBufferFloatImage() {
  31394. + DataType tensorBufferDataType = FLOAT32;
  31395. + DataType tensorImageDataType = FLOAT32;
  31396. + boolean isNormalized = true;
  31397. + ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
  31398. +
  31399. + TensorBuffer tensorBuffer =
  31400. + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
  31401. + TensorImage tensorImage = new TensorImage(tensorImageDataType);
  31402. +
  31403. + tensorImage.load(tensorBuffer, colorSpaceType);
  31404. + Bitmap bitmap = tensorImage.getBitmap();
  31405. +
  31406. + Bitmap expectedBitmap = createBitmap(colorSpaceType);
  31407. + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  31408. + }
  31409. +
  31410. + @Test
  31411. + public void loadAndGetBitmapSucceedsWithFloatBufferUINT8Image() {
  31412. + DataType tensorBufferDataType = FLOAT32;
  31413. + DataType tensorImageDataType = UINT8;
  31414. + boolean isNormalized = false;
  31415. + ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
  31416. +
  31417. + TensorBuffer tensorBuffer =
  31418. + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
  31419. + TensorImage tensorImage = new TensorImage(tensorImageDataType);
  31420. - /**
  31421. - * Difference between the pair of float and uint8 values. It is used to test the data conversion.
  31422. - */
  31423. - private static final float DELTA = 0.1f;
  31424. -
  31425. - // Note that parameterized test with android_library_instrumentation_tests is currently not
  31426. - // supported in internally.
  31427. - @Test
  31428. - public void loadAndGetBitmapSucceedsWithFloatBufferFloatImage() {
  31429. - DataType tensorBufferDataType = FLOAT32;
  31430. - DataType tensorImageDataType = FLOAT32;
  31431. - boolean isNormalized = true;
  31432. - ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
  31433. -
  31434. - TensorBuffer tensorBuffer =
  31435. - createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
  31436. - TensorImage tensorImage = new TensorImage(tensorImageDataType);
  31437. -
  31438. - tensorImage.load(tensorBuffer, colorSpaceType);
  31439. - Bitmap bitmap = tensorImage.getBitmap();
  31440. -
  31441. - Bitmap expectedBitmap = createBitmap(colorSpaceType);
  31442. - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  31443. - }
  31444. -
  31445. - @Test
  31446. - public void loadAndGetBitmapSucceedsWithFloatBufferUINT8Image() {
  31447. - DataType tensorBufferDataType = FLOAT32;
  31448. - DataType tensorImageDataType = UINT8;
  31449. - boolean isNormalized = false;
  31450. - ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
  31451. -
  31452. - TensorBuffer tensorBuffer =
  31453. - createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
  31454. - TensorImage tensorImage = new TensorImage(tensorImageDataType);
  31455. -
  31456. - tensorImage.load(tensorBuffer, colorSpaceType);
  31457. - Bitmap bitmap = tensorImage.getBitmap();
  31458. -
  31459. - Bitmap expectedBitmap = createBitmap(colorSpaceType);
  31460. - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  31461. - }
  31462. -
  31463. - @Test
  31464. - public void loadAndGetBitmapSucceedsWithUINT8BufferFloatImage() {
  31465. - DataType tensorBufferDataType = UINT8;
  31466. - DataType tensorImageDataType = FLOAT32;
  31467. - boolean isNormalized = true;
  31468. - ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
  31469. -
  31470. - TensorBuffer tensorBuffer =
  31471. - createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
  31472. - TensorImage tensorImage = new TensorImage(tensorImageDataType);
  31473. -
  31474. - tensorImage.load(tensorBuffer, colorSpaceType);
  31475. - Bitmap bitmap = tensorImage.getBitmap();
  31476. -
  31477. - Bitmap expectedBitmap = createBitmap(colorSpaceType);
  31478. - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  31479. - }
  31480. -
  31481. - @Test
  31482. - public void loadAndGetBitmapSucceedsWithUINT8BufferUINT8Image() {
  31483. - DataType tensorBufferDataType = UINT8;
  31484. - DataType tensorImageDataType = UINT8;
  31485. - boolean isNormalized = false;
  31486. - ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
  31487. -
  31488. - TensorBuffer tensorBuffer =
  31489. - createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
  31490. - TensorImage tensorImage = new TensorImage(tensorImageDataType);
  31491. -
  31492. - tensorImage.load(tensorBuffer, colorSpaceType);
  31493. - Bitmap bitmap = tensorImage.getBitmap();
  31494. -
  31495. - Bitmap expectedBitmap = createBitmap(colorSpaceType);
  31496. - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  31497. - }
  31498. -
  31499. - private static TensorBuffer createTensorBuffer(
  31500. - DataType dataType, boolean isNormalized, ColorSpaceType colorSpaceType, float delta) {
  31501. - switch (colorSpaceType) {
  31502. - case RGB:
  31503. - return createRgbTensorBuffer(dataType, isNormalized, delta);
  31504. - case GRAYSCALE:
  31505. - return createGrayscaleTensorBuffer(dataType, isNormalized, delta);
  31506. - default:
  31507. - break;
  31508. + tensorImage.load(tensorBuffer, colorSpaceType);
  31509. + Bitmap bitmap = tensorImage.getBitmap();
  31510. +
  31511. + Bitmap expectedBitmap = createBitmap(colorSpaceType);
  31512. + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  31513. }
  31514. - throw new IllegalArgumentException(
  31515. - "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
  31516. - }
  31517. -
  31518. - private static Bitmap createBitmap(ColorSpaceType colorSpaceType) {
  31519. - switch (colorSpaceType) {
  31520. - case RGB:
  31521. - return createRgbBitmap();
  31522. - case GRAYSCALE:
  31523. - return createGrayscaleBitmap();
  31524. - default:
  31525. - break;
  31526. +
  31527. + @Test
  31528. + public void loadAndGetBitmapSucceedsWithUINT8BufferFloatImage() {
  31529. + DataType tensorBufferDataType = UINT8;
  31530. + DataType tensorImageDataType = FLOAT32;
  31531. + boolean isNormalized = true;
  31532. + ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
  31533. +
  31534. + TensorBuffer tensorBuffer =
  31535. + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
  31536. + TensorImage tensorImage = new TensorImage(tensorImageDataType);
  31537. +
  31538. + tensorImage.load(tensorBuffer, colorSpaceType);
  31539. + Bitmap bitmap = tensorImage.getBitmap();
  31540. +
  31541. + Bitmap expectedBitmap = createBitmap(colorSpaceType);
  31542. + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  31543. + }
  31544. +
  31545. + @Test
  31546. + public void loadAndGetBitmapSucceedsWithUINT8BufferUINT8Image() {
  31547. + DataType tensorBufferDataType = UINT8;
  31548. + DataType tensorImageDataType = UINT8;
  31549. + boolean isNormalized = false;
  31550. + ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
  31551. +
  31552. + TensorBuffer tensorBuffer =
  31553. + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
  31554. + TensorImage tensorImage = new TensorImage(tensorImageDataType);
  31555. +
  31556. + tensorImage.load(tensorBuffer, colorSpaceType);
  31557. + Bitmap bitmap = tensorImage.getBitmap();
  31558. +
  31559. + Bitmap expectedBitmap = createBitmap(colorSpaceType);
  31560. + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  31561. + }
  31562. +
  31563. + private static TensorBuffer createTensorBuffer(
  31564. + DataType dataType, boolean isNormalized, ColorSpaceType colorSpaceType, float delta) {
  31565. + switch (colorSpaceType) {
  31566. + case RGB:
  31567. + return createRgbTensorBuffer(dataType, isNormalized, delta);
  31568. + case GRAYSCALE:
  31569. + return createGrayscaleTensorBuffer(dataType, isNormalized, delta);
  31570. + default:
  31571. + break;
  31572. + }
  31573. + throw new IllegalArgumentException(
  31574. + "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
  31575. + }
  31576. +
  31577. + private static Bitmap createBitmap(ColorSpaceType colorSpaceType) {
  31578. + switch (colorSpaceType) {
  31579. + case RGB:
  31580. + return createRgbBitmap();
  31581. + case GRAYSCALE:
  31582. + return createGrayscaleBitmap();
  31583. + default:
  31584. + break;
  31585. + }
  31586. + throw new IllegalArgumentException(
  31587. + "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
  31588. }
  31589. - throw new IllegalArgumentException(
  31590. - "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
  31591. - }
  31592. }
  31593. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java
  31594. index f27edef4de779..b3130f4f2073c 100644
  31595. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java
  31596. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java
  31597. @@ -16,6 +16,7 @@ limitations under the License.
  31598. package org.tensorflow.lite.support.image;
  31599. import static com.google.common.truth.Truth.assertThat;
  31600. +
  31601. import static org.junit.Assert.assertArrayEquals;
  31602. import static org.junit.Assert.assertThrows;
  31603. import static org.mockito.Mockito.when;
  31604. @@ -31,9 +32,7 @@ import android.graphics.Bitmap.Config;
  31605. import android.graphics.Color;
  31606. import android.graphics.ImageFormat;
  31607. import android.media.Image;
  31608. -import java.nio.ByteBuffer;
  31609. -import java.util.Arrays;
  31610. -import java.util.Collection;
  31611. +
  31612. import org.junit.Before;
  31613. import org.junit.Test;
  31614. import org.junit.runner.RunWith;
  31615. @@ -48,713 +47,689 @@ import org.robolectric.RobolectricTestRunner;
  31616. import org.tensorflow.lite.DataType;
  31617. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  31618. +import java.nio.ByteBuffer;
  31619. +import java.util.Arrays;
  31620. +import java.util.Collection;
  31621. +
  31622. /** Tests of {@link org.tensorflow.lite.support.image.TensorImage}. */
  31623. @RunWith(Suite.class)
  31624. -@SuiteClasses({
  31625. - TensorImageTest.General.class,
  31626. - TensorImageTest.LoadTensorBufferWithRgbAndGrayscale.class,
  31627. - TensorImageTest.LoadTensorBufferWithInvalidShapeTest.class,
  31628. - TensorImageTest.LoadTensorBufferWithYUV.class,
  31629. - TensorImageTest.LoadTensorBufferWithImageProperties.class
  31630. -})
  31631. +@SuiteClasses(
  31632. + {TensorImageTest.General.class, TensorImageTest.LoadTensorBufferWithRgbAndGrayscale.class,
  31633. + TensorImageTest.LoadTensorBufferWithInvalidShapeTest.class,
  31634. + TensorImageTest.LoadTensorBufferWithYUV.class,
  31635. + TensorImageTest.LoadTensorBufferWithImageProperties.class})
  31636. public class TensorImageTest {
  31637. -
  31638. - @RunWith(RobolectricTestRunner.class)
  31639. - public static final class General extends TensorImageTest {
  31640. -
  31641. - private static final Bitmap exampleBitmap = createExampleBitmap();
  31642. - private static final float[] exampleFloatPixels = createExampleFloatPixels();
  31643. - private static final int[] exampleUint8Pixels = createExampleUint8Pixels();
  31644. -
  31645. - private static final int EXAMPLE_WIDTH = 5;
  31646. - private static final int EXAMPLE_HEIGHT = 10;
  31647. - private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH;
  31648. - private static final int EXAMPLE_NUM_CHANNELS = 3;
  31649. - private static final int[] EXAMPLE_SHAPE = {
  31650. - EXAMPLE_HEIGHT, EXAMPLE_WIDTH, EXAMPLE_NUM_CHANNELS
  31651. - };
  31652. - private static final float MEAN = 127.5f;
  31653. - private static final float STDDEV = 127.5f;
  31654. -
  31655. - @Mock Image imageMock;
  31656. -
  31657. - @Before
  31658. - public void setUp() {
  31659. - MockitoAnnotations.initMocks(this);
  31660. - }
  31661. -
  31662. - @Test
  31663. - public void defaultConstructorCreatesUint8TensorImage() {
  31664. - TensorImage image = new TensorImage();
  31665. - assertThat(image.getDataType()).isEqualTo(UINT8);
  31666. - }
  31667. -
  31668. - @Test
  31669. - public void createFromSucceedsWithUint8TensorImage() {
  31670. - TensorImage uint8Image = new TensorImage(UINT8);
  31671. - uint8Image.load(new int[] {1, 2, 3, 4, -5, 600}, new int[] {2, 1, 3});
  31672. -
  31673. - TensorImage floatImage = TensorImage.createFrom(uint8Image, FLOAT32);
  31674. - float[] pixels = floatImage.getTensorBuffer().getFloatArray();
  31675. - assertThat(pixels).isEqualTo(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 0.0f, 255.0f});
  31676. - }
  31677. -
  31678. - @Test
  31679. - public void createFromSucceedsWithFloatTensorImage() {
  31680. - TensorImage floatImage = new TensorImage(FLOAT32);
  31681. - floatImage.load(new float[] {1, 2.495f, 3.5f, 4.5f, -5, 600}, new int[] {2, 1, 3});
  31682. -
  31683. - TensorImage uint8Image = TensorImage.createFrom(floatImage, UINT8);
  31684. - int[] pixels = uint8Image.getTensorBuffer().getIntArray();
  31685. - assertThat(pixels).isEqualTo(new int[] {1, 2, 3, 4, 0, 255});
  31686. - }
  31687. -
  31688. - @Test
  31689. - public void loadBitmapSucceedsWithUint8TensorImage() {
  31690. - Bitmap rgbBitmap = createRgbBitmap();
  31691. - TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(UINT8, false, 0.0f);
  31692. - TensorImage uint8Image = new TensorImage(UINT8);
  31693. -
  31694. - uint8Image.load(rgbBitmap);
  31695. - assertThat(uint8Image.getBitmap().sameAs(rgbBitmap)).isTrue();
  31696. - assertEqualTensorBuffers(uint8Image.getTensorBuffer(), rgbTensorBuffer);
  31697. - assertThat(uint8Image.getDataType()).isEqualTo(UINT8);
  31698. - }
  31699. -
  31700. - @Test
  31701. - public void loadBitmapSucceedsWithFloatTensorImage() {
  31702. - Bitmap rgbBitmap = createRgbBitmap();
  31703. - TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(FLOAT32, false, 0.0f);
  31704. - TensorImage floatImage = new TensorImage(FLOAT32);
  31705. -
  31706. - floatImage.load(rgbBitmap);
  31707. - assertThat(floatImage.getBitmap().sameAs(rgbBitmap)).isTrue();
  31708. - assertEqualTensorBuffers(floatImage.getTensorBuffer(), rgbTensorBuffer);
  31709. - assertThat(floatImage.getDataType()).isEqualTo(FLOAT32);
  31710. - }
  31711. -
  31712. - @Test
  31713. - public void loadFloatArrayWithUint8TensorImage() {
  31714. - TensorImage uint8Image = new TensorImage(UINT8);
  31715. -
  31716. - uint8Image.load(exampleFloatPixels, EXAMPLE_SHAPE);
  31717. - assertThat(uint8Image.getBitmap()).isNotNull();
  31718. - assertThat(uint8Image.getTensorBuffer().getFloatArray())
  31719. - .isEqualTo(
  31720. - new float
  31721. - [exampleFloatPixels
  31722. - .length]); // All zero because of normalization and casting when loading.
  31723. - }
  31724. -
  31725. - @Test
  31726. - public void loadFloatArrayWithFloatTensorImage() {
  31727. - TensorImage floatImage = new TensorImage(FLOAT32);
  31728. -
  31729. - floatImage.load(exampleFloatPixels, EXAMPLE_SHAPE);
  31730. - assertThat(floatImage.getTensorBuffer().getFloatArray()).isEqualTo(exampleFloatPixels);
  31731. - }
  31732. -
  31733. - @Test
  31734. - public void loadUint8ArrayWithUint8TensorImage() {
  31735. - TensorImage uint8Image = new TensorImage(UINT8);
  31736. -
  31737. - uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE);
  31738. - assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
  31739. - assertThat(uint8Image.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
  31740. - }
  31741. -
  31742. - @Test
  31743. - public void loadUint8ArrayWithFloatTensorImage() {
  31744. - TensorImage floatImage = new TensorImage(FLOAT32);
  31745. -
  31746. - floatImage.load(exampleUint8Pixels, EXAMPLE_SHAPE);
  31747. - assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
  31748. - }
  31749. -
  31750. - @Test
  31751. - public void loadTensorBufferWithUint8TensorImage() {
  31752. - TensorImage uint8Image = new TensorImage(UINT8);
  31753. -
  31754. - uint8Image.load(exampleBitmap);
  31755. - TensorBuffer buffer = uint8Image.getTensorBuffer();
  31756. - uint8Image.load(buffer);
  31757. - assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
  31758. - }
  31759. -
  31760. - @Test
  31761. - public void loadTensorBufferWithFloatTensorImage() {
  31762. - TensorImage floatImage = new TensorImage(FLOAT32);
  31763. -
  31764. - floatImage.load(exampleBitmap);
  31765. - TensorBuffer buffer = floatImage.getTensorBuffer();
  31766. - floatImage.load(buffer);
  31767. - assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
  31768. - }
  31769. -
  31770. - @Test
  31771. - public void loadAndGetMediaImageSucceedsWithYuv420888Format() {
  31772. - setUpImageMock(imageMock, ImageFormat.YUV_420_888);
  31773. - TensorImage tensorImage = new TensorImage(UINT8);
  31774. -
  31775. - tensorImage.load(imageMock);
  31776. - Image imageReturned = tensorImage.getMediaImage();
  31777. -
  31778. - assertThat(imageReturned).isEqualTo(imageMock);
  31779. - }
  31780. -
  31781. - @Test
  31782. - public void loadMediaImageFailsWithNonYuv420888Format() {
  31783. - setUpImageMock(imageMock, ImageFormat.YUV_422_888);
  31784. - TensorImage tensorImage = new TensorImage(UINT8);
  31785. -
  31786. - IllegalArgumentException exception =
  31787. - assertThrows(IllegalArgumentException.class, () -> tensorImage.load(imageMock));
  31788. - assertThat(exception).hasMessageThat().contains("Only supports loading YUV_420_888 Image.");
  31789. - }
  31790. -
  31791. - @Test
  31792. - public void getBitmapWithUint8TensorImage() {
  31793. - TensorImage uint8Image = new TensorImage(UINT8);
  31794. -
  31795. - uint8Image.load(exampleBitmap);
  31796. - assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
  31797. - // Also check zero copy is effective here (input and output are references of the same
  31798. - // object).
  31799. - assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap);
  31800. - // Also check we don't create new Bitmap only with reading operations.
  31801. - assertThat(uint8Image.getBuffer().limit())
  31802. - .isEqualTo(EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS);
  31803. - assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap);
  31804. -
  31805. - uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE);
  31806. - assertThat(uint8Image.getBitmap()).isNotSameInstanceAs(exampleBitmap);
  31807. - }
  31808. -
  31809. - @Test
  31810. - public void getBitmapWithFloatTensorImage() {
  31811. - TensorImage floatImage = new TensorImage(FLOAT32);
  31812. -
  31813. - floatImage.load(exampleBitmap);
  31814. - assertThat(floatImage.getBitmap()).isSameInstanceAs(exampleBitmap);
  31815. - }
  31816. -
  31817. - @Test
  31818. - public void getBitmapWithEmptyTensorImage() {
  31819. - TensorImage uint8Image = new TensorImage(UINT8);
  31820. -
  31821. - assertThrows(IllegalStateException.class, uint8Image::getBitmap);
  31822. - }
  31823. -
  31824. - @Test
  31825. - public void getMediaImageFailsWithBackedBitmap() {
  31826. - TensorImage tensorImage = TensorImage.fromBitmap(exampleBitmap);
  31827. -
  31828. - UnsupportedOperationException exception =
  31829. - assertThrows(UnsupportedOperationException.class, () -> tensorImage.getMediaImage());
  31830. - assertThat(exception)
  31831. - .hasMessageThat()
  31832. - .contains("Converting from Bitmap to android.media.Image is unsupported.");
  31833. - }
  31834. -
  31835. - @Test
  31836. - public void getMediaImageFailsWithBackedTensorBuffer() {
  31837. - TensorImage tensorImage = new TensorImage(UINT8);
  31838. - tensorImage.load(exampleFloatPixels, EXAMPLE_SHAPE);
  31839. -
  31840. - UnsupportedOperationException exception =
  31841. - assertThrows(UnsupportedOperationException.class, () -> tensorImage.getMediaImage());
  31842. - assertThat(exception)
  31843. - .hasMessageThat()
  31844. - .contains("Converting from TensorBuffer to android.media.Image is unsupported.");
  31845. - }
  31846. -
  31847. - @Test
  31848. - public void getShapeOfInternalBitmapShouldSuccess() {
  31849. - Bitmap bitmap = Bitmap.createBitmap(300, 400, Config.ARGB_8888);
  31850. - TensorImage image = TensorImage.fromBitmap(bitmap);
  31851. -
  31852. - int width = image.getWidth();
  31853. - int height = image.getHeight();
  31854. -
  31855. - assertThat(width).isEqualTo(300);
  31856. - assertThat(height).isEqualTo(400);
  31857. - }
  31858. -
  31859. - @Test
  31860. - public void getShapeOfInternalTensorBufferShouldSuccess() {
  31861. - TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 400, 300, 3}, UINT8);
  31862. - TensorImage image = new TensorImage();
  31863. - image.load(buffer);
  31864. -
  31865. - int width = image.getWidth();
  31866. - int height = image.getHeight();
  31867. -
  31868. - assertThat(width).isEqualTo(300);
  31869. - assertThat(height).isEqualTo(400);
  31870. - }
  31871. -
  31872. - @Test
  31873. - public void getShapeOfNullImageShouldThrow() {
  31874. - TensorImage image = new TensorImage();
  31875. -
  31876. - assertThrows(IllegalStateException.class, image::getHeight);
  31877. - }
  31878. -
  31879. - @Test
  31880. - public void getShapeOfACorruptedBufferShouldThrowRatherThanCrash() {
  31881. - int[] data = new int[] {1, 2, 3, 4, 5, 6};
  31882. - TensorBuffer buffer = TensorBuffer.createDynamic(UINT8);
  31883. - buffer.loadArray(data, new int[] {1, 1, 2, 3});
  31884. - TensorImage image = new TensorImage();
  31885. - image.load(buffer);
  31886. - // Reload data but with an invalid shape, which leads to `buffer` corrupted.
  31887. - int[] newData = new int[] {1, 2, 3};
  31888. - buffer.loadArray(newData, new int[] {1, 1, 1, 3});
  31889. -
  31890. - assertThrows(IllegalArgumentException.class, image::getHeight);
  31891. - }
  31892. -
  31893. - @Test
  31894. - public void getColorSpaceTypeSucceedsWithBitmapARGB_8888() {
  31895. - Bitmap rgbBitmap = createRgbBitmap();
  31896. - TensorImage tensorImage = TensorImage.fromBitmap(rgbBitmap);
  31897. -
  31898. - assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
  31899. - }
  31900. -
  31901. - @Test
  31902. - public void getColorSpaceTypeSucceedsWithRgbTensorBuffer() {
  31903. - TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false);
  31904. - TensorImage tensorImage = new TensorImage();
  31905. - tensorImage.load(rgbBuffer);
  31906. -
  31907. - assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
  31908. - }
  31909. -
  31910. - @Test
  31911. - public void getColorSpaceTypeSucceedsWithGrayscaleTensorBuffer() {
  31912. - TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false);
  31913. - TensorImage tensorImage = new TensorImage();
  31914. - tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE);
  31915. -
  31916. - assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
  31917. - }
  31918. -
  31919. - @Test
  31920. - public void getColorSpaceTypeSucceedsWithRepeatedLoading() {
  31921. - TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false);
  31922. - TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false);
  31923. - Bitmap rgbBitmap = createRgbBitmap();
  31924. - TensorImage tensorImage = new TensorImage();
  31925. -
  31926. - tensorImage.load(rgbBuffer);
  31927. - assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
  31928. - tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE);
  31929. - assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
  31930. - tensorImage.load(rgbBitmap);
  31931. - assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
  31932. - }
  31933. -
  31934. - @Test
  31935. - public void getColorSpaceTypeFailsWhenNoImageHasBeenLoaded() {
  31936. - TensorImage tensorImage = new TensorImage();
  31937. -
  31938. - IllegalStateException exception =
  31939. - assertThrows(IllegalStateException.class, tensorImage::getColorSpaceType);
  31940. - assertThat(exception).hasMessageThat().contains("No image has been loaded yet.");
  31941. - }
  31942. -
  31943. - /**
  31944. - * Creates an example bit map, which is a 10x10 ARGB bitmap and pixels are set by: pixel[i] =
  31945. - * {A: 0, B: i + 2, G: i + 1, G: i}, where i is the flatten index
  31946. - */
  31947. - private static Bitmap createExampleBitmap() {
  31948. - int[] colors = new int[EXAMPLE_NUM_PIXELS];
  31949. - for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) {
  31950. - colors[i] = Color.rgb(i, i + 1, i + 2);
  31951. - }
  31952. -
  31953. - return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
  31954. - }
  31955. -
  31956. - private static float[] createExampleFloatPixels() {
  31957. - float[] pixels = new float[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS];
  31958. - for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) {
  31959. - pixels[j++] = (i - MEAN) / STDDEV;
  31960. - pixels[j++] = (i + 1 - MEAN) / STDDEV;
  31961. - pixels[j++] = (i + 2 - MEAN) / STDDEV;
  31962. - }
  31963. - return pixels;
  31964. - }
  31965. -
  31966. - private static int[] createExampleUint8Pixels() {
  31967. - int[] pixels = new int[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS];
  31968. - for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) {
  31969. - pixels[j++] = i;
  31970. - pixels[j++] = i + 1;
  31971. - pixels[j++] = i + 2;
  31972. - }
  31973. - return pixels;
  31974. - }
  31975. - }
  31976. -
  31977. - /** Parameterized tests for loading TensorBuffers with RGB and Grayscale images. */
  31978. - @RunWith(ParameterizedRobolectricTestRunner.class)
  31979. - public static final class LoadTensorBufferWithRgbAndGrayscale extends TensorImageTest {
  31980. -
  31981. - /**
  31982. - * Difference between the pair of float and uint8 values. It is used to test the data
  31983. - * conversion.
  31984. - */
  31985. - private static final float DELTA = 0.1f;
  31986. -
  31987. - /** The data type that used to create the TensorBuffer. */
  31988. - @Parameter(0)
  31989. - public DataType tensorBufferDataType;
  31990. -
  31991. - /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */
  31992. - @Parameter(1)
  31993. - public boolean isNormalized;
  31994. -
  31995. - /** The color space type of the TensorBuffer. */
  31996. - @Parameter(2)
  31997. - public ColorSpaceType colorSpaceType;
  31998. -
  31999. - /** The data type that used to create the TensorImage. */
  32000. - @Parameter(3)
  32001. - public DataType tensorImageDataType;
  32002. -
  32003. - @Parameters(
  32004. - name =
  32005. - "tensorBufferDataType={0}; isNormalized={1}; colorSpaceType={2};"
  32006. - + " tensorImageDataType={3}")
  32007. - public static Collection<Object[]> data() {
  32008. - return Arrays.asList(
  32009. - new Object[][] {
  32010. - {FLOAT32, true, ColorSpaceType.RGB, FLOAT32},
  32011. - {FLOAT32, false, ColorSpaceType.RGB, UINT8},
  32012. - {UINT8, true, ColorSpaceType.RGB, FLOAT32},
  32013. - {UINT8, false, ColorSpaceType.RGB, UINT8},
  32014. - });
  32015. - }
  32016. -
  32017. - @Test
  32018. - public void loadAndGetBitmapSucceedsWithTensorBufferAndColorSpaceType() {
  32019. - TensorBuffer tensorBuffer =
  32020. - createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
  32021. - TensorImage tensorImage = new TensorImage(tensorImageDataType);
  32022. -
  32023. - tensorImage.load(tensorBuffer, colorSpaceType);
  32024. - Bitmap bitmap = tensorImage.getBitmap();
  32025. -
  32026. - Bitmap expectedBitmap = createBitmap(colorSpaceType);
  32027. - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  32028. - }
  32029. -
  32030. - @Test
  32031. - public void loadAndGetTensorBufferSucceedsWithTensorBufferAndColorSpaceType() {
  32032. - TensorBuffer tensorBuffer =
  32033. - createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
  32034. - TensorImage tensorImage = new TensorImage(tensorImageDataType);
  32035. -
  32036. - tensorImage.load(tensorBuffer, colorSpaceType);
  32037. - TensorBuffer buffer = tensorImage.getTensorBuffer();
  32038. -
  32039. - // If tensorBufferDataType is UINT8, expectedTensorBuffer should not contain delta.
  32040. - float expectedResidual = tensorBufferDataType == UINT8 ? 0.f : DELTA;
  32041. - TensorBuffer expectedTensorBuffer =
  32042. - createTensorBuffer(tensorImageDataType, isNormalized, colorSpaceType, expectedResidual);
  32043. - assertEqualTensorBuffers(buffer, expectedTensorBuffer);
  32044. - }
  32045. -
  32046. - private static TensorBuffer createTensorBuffer(
  32047. - DataType dataType, boolean isNormalized, ColorSpaceType colorSpaceType, float delta) {
  32048. - switch (colorSpaceType) {
  32049. - case RGB:
  32050. - return createRgbTensorBuffer(dataType, isNormalized, delta);
  32051. - case GRAYSCALE:
  32052. - return createGrayscaleTensorBuffer(dataType, isNormalized, delta);
  32053. - default:
  32054. - break;
  32055. - }
  32056. - throw new IllegalArgumentException(
  32057. - "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
  32058. - }
  32059. -
  32060. - private static Bitmap createBitmap(ColorSpaceType colorSpaceType) {
  32061. - switch (colorSpaceType) {
  32062. - case RGB:
  32063. - return createRgbBitmap();
  32064. - case GRAYSCALE:
  32065. - return createGrayscaleBitmap();
  32066. - default:
  32067. - break;
  32068. - }
  32069. - throw new IllegalArgumentException(
  32070. - "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
  32071. - }
  32072. - }
  32073. -
  32074. - /** Parameterized tests for loading TensorBuffers with YUV images. */
  32075. - @RunWith(ParameterizedRobolectricTestRunner.class)
  32076. - public static final class LoadTensorBufferWithYUV extends TensorImageTest {
  32077. -
  32078. - private static final int HEIGHT = 2;
  32079. - private static final int WIDTH = 3;
  32080. -
  32081. - @Parameter(0)
  32082. - public ColorSpaceType colorSpaceType;
  32083. -
  32084. - @Parameters(name = "colorSpaceType={0}")
  32085. - public static Collection<Object[]> data() {
  32086. - return Arrays.asList(
  32087. - new Object[][] {
  32088. - {ColorSpaceType.NV12},
  32089. - {ColorSpaceType.NV21},
  32090. - {ColorSpaceType.YV12},
  32091. - {ColorSpaceType.YV21},
  32092. - });
  32093. - }
  32094. -
  32095. - @Test
  32096. - public void loadTensorBufferWithColorSpaceShouldFail() {
  32097. - TensorImage tensorImage = new TensorImage();
  32098. -
  32099. - IllegalArgumentException exception =
  32100. - assertThrows(
  32101. - IllegalArgumentException.class,
  32102. - () -> tensorImage.load(TensorBuffer.createDynamic(DataType.FLOAT32), colorSpaceType));
  32103. - assertThat(exception)
  32104. - .hasMessageThat()
  32105. - .contains(
  32106. - "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
  32107. - + " `load(TensorBuffer, ImageProperties)` for other color space types.");
  32108. - }
  32109. -
  32110. - @Test
  32111. - public void loadTensorBufferAndGetBitmapShouldFail() {
  32112. - int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
  32113. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  32114. - tensorBuffer.loadArray(data, new int[] {data.length});
  32115. -
  32116. - ImageProperties imageProperties =
  32117. - ImageProperties.builder()
  32118. - .setHeight(HEIGHT)
  32119. - .setWidth(WIDTH)
  32120. - .setColorSpaceType(colorSpaceType)
  32121. - .build();
  32122. -
  32123. - TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
  32124. - tensorImage.load(tensorBuffer, imageProperties);
  32125. -
  32126. - UnsupportedOperationException exception =
  32127. - assertThrows(UnsupportedOperationException.class, () -> tensorImage.getBitmap());
  32128. - assertThat(exception)
  32129. - .hasMessageThat()
  32130. - .contains(
  32131. - "convertTensorBufferToBitmap() is unsupported for the color space type "
  32132. - + colorSpaceType.name());
  32133. - }
  32134. - }
  32135. -
  32136. - /** Parameterized tests for loading TensorBuffers with ImageProperties. */
  32137. - @RunWith(ParameterizedRobolectricTestRunner.class)
  32138. - public static final class LoadTensorBufferWithImageProperties extends TensorImageTest {
  32139. -
  32140. - private static final int HEIGHT = 2;
  32141. - private static final int WIDTH = 3;
  32142. - private static final int WRONG_WIDTH = 10;
  32143. -
  32144. - @Parameter(0)
  32145. - public ColorSpaceType colorSpaceType;
  32146. -
  32147. - @Parameters(name = "colorSpaceType={0}")
  32148. - public static Collection<Object[]> data() {
  32149. - return Arrays.asList(
  32150. - new Object[][] {
  32151. - {ColorSpaceType.RGB},
  32152. - {ColorSpaceType.GRAYSCALE},
  32153. - {ColorSpaceType.NV12},
  32154. - {ColorSpaceType.NV21},
  32155. - {ColorSpaceType.YV12},
  32156. - {ColorSpaceType.YV21},
  32157. - });
  32158. - }
  32159. -
  32160. - @Test
  32161. - public void loadAndGetTensorBufferShouldSucceedWithCorrectProperties() {
  32162. - int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
  32163. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  32164. - tensorBuffer.loadArray(data, new int[] {data.length});
  32165. -
  32166. - ImageProperties imageProperties =
  32167. - ImageProperties.builder()
  32168. - .setHeight(HEIGHT)
  32169. - .setWidth(WIDTH)
  32170. - .setColorSpaceType(colorSpaceType)
  32171. - .build();
  32172. -
  32173. - TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
  32174. - tensorImage.load(tensorBuffer, imageProperties);
  32175. -
  32176. - assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer);
  32177. - }
  32178. -
  32179. - @Test
  32180. - public void loadAndGetTensorBufferShouldSucceedWithLargerBuffer() {
  32181. - // Should allow buffer to be greater than the size specified by height and width.
  32182. - int moreElements = 1;
  32183. - int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH) + moreElements];
  32184. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  32185. - tensorBuffer.loadArray(data, new int[] {data.length});
  32186. -
  32187. - ImageProperties imageProperties =
  32188. - ImageProperties.builder()
  32189. - .setHeight(HEIGHT)
  32190. - .setWidth(WIDTH)
  32191. - .setColorSpaceType(colorSpaceType)
  32192. - .build();
  32193. -
  32194. - TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
  32195. - tensorImage.load(tensorBuffer, imageProperties);
  32196. -
  32197. - assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer);
  32198. - }
  32199. -
  32200. - @Test
  32201. - public void loadAndGetByteBufferShouldSucceedWithCorrectProperties() {
  32202. - ByteBuffer byteBuffer = ByteBuffer.allocate(colorSpaceType.getNumElements(HEIGHT, WIDTH));
  32203. -
  32204. - ImageProperties imageProperties =
  32205. - ImageProperties.builder()
  32206. - .setHeight(HEIGHT)
  32207. - .setWidth(WIDTH)
  32208. - .setColorSpaceType(colorSpaceType)
  32209. - .build();
  32210. -
  32211. - TensorImage tensorImage = new TensorImage(DataType.UINT8);
  32212. - tensorImage.load(byteBuffer, imageProperties);
  32213. -
  32214. - assertEqualByteBuffers(tensorImage.getBuffer(), byteBuffer);
  32215. - }
  32216. -
  32217. - @Test
  32218. - public void loadTensorBufferWithShouldFailWithWrongImageShape() {
  32219. - int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
  32220. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  32221. - tensorBuffer.loadArray(data, new int[] {data.length});
  32222. -
  32223. - ImageProperties imageProperties =
  32224. - ImageProperties.builder()
  32225. - .setHeight(HEIGHT)
  32226. - .setWidth(WRONG_WIDTH)
  32227. - .setColorSpaceType(colorSpaceType)
  32228. - .build();
  32229. -
  32230. - TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
  32231. -
  32232. - IllegalArgumentException exception =
  32233. - assertThrows(
  32234. - IllegalArgumentException.class,
  32235. - () -> tensorImage.load(tensorBuffer, imageProperties));
  32236. - assertThat(exception)
  32237. - .hasMessageThat()
  32238. - .contains(
  32239. - String.format(
  32240. - "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
  32241. - + " expected number of elements should be at least %d.",
  32242. - data.length,
  32243. - colorSpaceType.name(),
  32244. - HEIGHT,
  32245. - WRONG_WIDTH,
  32246. - colorSpaceType.getNumElements(HEIGHT, WRONG_WIDTH)));
  32247. - }
  32248. -
  32249. - @Test
  32250. - public void getShapeOfInternalTensorBufferShouldSuccess() {
  32251. - int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
  32252. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  32253. - tensorBuffer.loadArray(data, new int[] {data.length});
  32254. -
  32255. - ImageProperties imageProperties =
  32256. - ImageProperties.builder()
  32257. - .setHeight(HEIGHT)
  32258. - .setWidth(WIDTH)
  32259. - .setColorSpaceType(colorSpaceType)
  32260. - .build();
  32261. -
  32262. - TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
  32263. - tensorImage.load(tensorBuffer, imageProperties);
  32264. -
  32265. - assertThat(tensorImage.getWidth()).isEqualTo(WIDTH);
  32266. - assertThat(tensorImage.getHeight()).isEqualTo(HEIGHT);
  32267. - }
  32268. - }
  32269. -
  32270. - /** Parameterized tests for loading TensorBuffer with invalid shapes. */
  32271. - @RunWith(ParameterizedRobolectricTestRunner.class)
  32272. - public static final class LoadTensorBufferWithInvalidShapeTest extends TensorImageTest {
  32273. -
  32274. - private static final String RGB_ASSERT_SHAPE_MESSAGE =
  32275. - "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
  32276. - + " representing R, G, B in order. The provided image shape is ";
  32277. - private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE =
  32278. - "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
  32279. - + " shape is ";
  32280. -
  32281. - @Parameter(0)
  32282. - public ColorSpaceType colorSpaceType;
  32283. -
  32284. - /** The shape that does not match the colorSpaceType. */
  32285. - @Parameter(1)
  32286. - public int[] invalidShape;
  32287. -
  32288. - @Parameter(2)
  32289. - public String errorMessage;
  32290. -
  32291. - @Parameters(name = "colorSpaceType={0}; invalidShape={1}")
  32292. - public static Collection<Object[]> data() {
  32293. - return Arrays.asList(
  32294. - new Object[][] {
  32295. - {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
  32296. - {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
  32297. - {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
  32298. - {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
  32299. - {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
  32300. - {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
  32301. - {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
  32302. - {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  32303. - {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  32304. - {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  32305. - {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  32306. - });
  32307. - }
  32308. -
  32309. - @Test
  32310. - public void loadTensorBufferWithInvalidShape() {
  32311. - TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(invalidShape, UINT8);
  32312. - TensorImage tensorImage = new TensorImage();
  32313. -
  32314. - IllegalArgumentException exception =
  32315. - assertThrows(
  32316. - IllegalArgumentException.class, () -> tensorImage.load(tensorBuffer, colorSpaceType));
  32317. - assertThat(exception).hasMessageThat().contains(errorMessage + Arrays.toString(invalidShape));
  32318. + @RunWith(RobolectricTestRunner.class)
  32319. + public static final class General extends TensorImageTest {
  32320. + private static final Bitmap exampleBitmap = createExampleBitmap();
  32321. + private static final float[] exampleFloatPixels = createExampleFloatPixels();
  32322. + private static final int[] exampleUint8Pixels = createExampleUint8Pixels();
  32323. +
  32324. + private static final int EXAMPLE_WIDTH = 5;
  32325. + private static final int EXAMPLE_HEIGHT = 10;
  32326. + private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH;
  32327. + private static final int EXAMPLE_NUM_CHANNELS = 3;
  32328. + private static final int[] EXAMPLE_SHAPE = {
  32329. + EXAMPLE_HEIGHT, EXAMPLE_WIDTH, EXAMPLE_NUM_CHANNELS};
  32330. + private static final float MEAN = 127.5f;
  32331. + private static final float STDDEV = 127.5f;
  32332. +
  32333. + @Mock
  32334. + Image imageMock;
  32335. +
  32336. + @Before
  32337. + public void setUp() {
  32338. + MockitoAnnotations.initMocks(this);
  32339. + }
  32340. +
  32341. + @Test
  32342. + public void defaultConstructorCreatesUint8TensorImage() {
  32343. + TensorImage image = new TensorImage();
  32344. + assertThat(image.getDataType()).isEqualTo(UINT8);
  32345. + }
  32346. +
  32347. + @Test
  32348. + public void createFromSucceedsWithUint8TensorImage() {
  32349. + TensorImage uint8Image = new TensorImage(UINT8);
  32350. + uint8Image.load(new int[] {1, 2, 3, 4, -5, 600}, new int[] {2, 1, 3});
  32351. +
  32352. + TensorImage floatImage = TensorImage.createFrom(uint8Image, FLOAT32);
  32353. + float[] pixels = floatImage.getTensorBuffer().getFloatArray();
  32354. + assertThat(pixels).isEqualTo(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 0.0f, 255.0f});
  32355. + }
  32356. +
  32357. + @Test
  32358. + public void createFromSucceedsWithFloatTensorImage() {
  32359. + TensorImage floatImage = new TensorImage(FLOAT32);
  32360. + floatImage.load(new float[] {1, 2.495f, 3.5f, 4.5f, -5, 600}, new int[] {2, 1, 3});
  32361. +
  32362. + TensorImage uint8Image = TensorImage.createFrom(floatImage, UINT8);
  32363. + int[] pixels = uint8Image.getTensorBuffer().getIntArray();
  32364. + assertThat(pixels).isEqualTo(new int[] {1, 2, 3, 4, 0, 255});
  32365. + }
  32366. +
  32367. + @Test
  32368. + public void loadBitmapSucceedsWithUint8TensorImage() {
  32369. + Bitmap rgbBitmap = createRgbBitmap();
  32370. + TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(UINT8, false, 0.0f);
  32371. + TensorImage uint8Image = new TensorImage(UINT8);
  32372. +
  32373. + uint8Image.load(rgbBitmap);
  32374. + assertThat(uint8Image.getBitmap().sameAs(rgbBitmap)).isTrue();
  32375. + assertEqualTensorBuffers(uint8Image.getTensorBuffer(), rgbTensorBuffer);
  32376. + assertThat(uint8Image.getDataType()).isEqualTo(UINT8);
  32377. + }
  32378. +
  32379. + @Test
  32380. + public void loadBitmapSucceedsWithFloatTensorImage() {
  32381. + Bitmap rgbBitmap = createRgbBitmap();
  32382. + TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(FLOAT32, false, 0.0f);
  32383. + TensorImage floatImage = new TensorImage(FLOAT32);
  32384. +
  32385. + floatImage.load(rgbBitmap);
  32386. + assertThat(floatImage.getBitmap().sameAs(rgbBitmap)).isTrue();
  32387. + assertEqualTensorBuffers(floatImage.getTensorBuffer(), rgbTensorBuffer);
  32388. + assertThat(floatImage.getDataType()).isEqualTo(FLOAT32);
  32389. + }
  32390. +
  32391. + @Test
  32392. + public void loadFloatArrayWithUint8TensorImage() {
  32393. + TensorImage uint8Image = new TensorImage(UINT8);
  32394. +
  32395. + uint8Image.load(exampleFloatPixels, EXAMPLE_SHAPE);
  32396. + assertThat(uint8Image.getBitmap()).isNotNull();
  32397. + assertThat(uint8Image.getTensorBuffer().getFloatArray())
  32398. + .isEqualTo(new float[exampleFloatPixels.length]); // All zero because of
  32399. + // normalization and casting
  32400. + // when loading.
  32401. + }
  32402. +
  32403. + @Test
  32404. + public void loadFloatArrayWithFloatTensorImage() {
  32405. + TensorImage floatImage = new TensorImage(FLOAT32);
  32406. +
  32407. + floatImage.load(exampleFloatPixels, EXAMPLE_SHAPE);
  32408. + assertThat(floatImage.getTensorBuffer().getFloatArray()).isEqualTo(exampleFloatPixels);
  32409. + }
  32410. +
  32411. + @Test
  32412. + public void loadUint8ArrayWithUint8TensorImage() {
  32413. + TensorImage uint8Image = new TensorImage(UINT8);
  32414. +
  32415. + uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE);
  32416. + assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
  32417. + assertThat(uint8Image.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
  32418. + }
  32419. +
  32420. + @Test
  32421. + public void loadUint8ArrayWithFloatTensorImage() {
  32422. + TensorImage floatImage = new TensorImage(FLOAT32);
  32423. +
  32424. + floatImage.load(exampleUint8Pixels, EXAMPLE_SHAPE);
  32425. + assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
  32426. + }
  32427. +
  32428. + @Test
  32429. + public void loadTensorBufferWithUint8TensorImage() {
  32430. + TensorImage uint8Image = new TensorImage(UINT8);
  32431. +
  32432. + uint8Image.load(exampleBitmap);
  32433. + TensorBuffer buffer = uint8Image.getTensorBuffer();
  32434. + uint8Image.load(buffer);
  32435. + assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
  32436. + }
  32437. +
  32438. + @Test
  32439. + public void loadTensorBufferWithFloatTensorImage() {
  32440. + TensorImage floatImage = new TensorImage(FLOAT32);
  32441. +
  32442. + floatImage.load(exampleBitmap);
  32443. + TensorBuffer buffer = floatImage.getTensorBuffer();
  32444. + floatImage.load(buffer);
  32445. + assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
  32446. + }
  32447. +
  32448. + @Test
  32449. + public void loadAndGetMediaImageSucceedsWithYuv420888Format() {
  32450. + setUpImageMock(imageMock, ImageFormat.YUV_420_888);
  32451. + TensorImage tensorImage = new TensorImage(UINT8);
  32452. +
  32453. + tensorImage.load(imageMock);
  32454. + Image imageReturned = tensorImage.getMediaImage();
  32455. +
  32456. + assertThat(imageReturned).isEqualTo(imageMock);
  32457. + }
  32458. +
  32459. + @Test
  32460. + public void loadMediaImageFailsWithNonYuv420888Format() {
  32461. + setUpImageMock(imageMock, ImageFormat.YUV_422_888);
  32462. + TensorImage tensorImage = new TensorImage(UINT8);
  32463. +
  32464. + IllegalArgumentException exception =
  32465. + assertThrows(IllegalArgumentException.class, () -> tensorImage.load(imageMock));
  32466. + assertThat(exception).hasMessageThat().contains(
  32467. + "Only supports loading YUV_420_888 Image.");
  32468. + }
  32469. +
  32470. + @Test
  32471. + public void getBitmapWithUint8TensorImage() {
  32472. + TensorImage uint8Image = new TensorImage(UINT8);
  32473. +
  32474. + uint8Image.load(exampleBitmap);
  32475. + assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
  32476. + // Also check zero copy is effective here (input and output are references of the same
  32477. + // object).
  32478. + assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap);
  32479. + // Also check we don't create new Bitmap only with reading operations.
  32480. + assertThat(uint8Image.getBuffer().limit())
  32481. + .isEqualTo(EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS);
  32482. + assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap);
  32483. +
  32484. + uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE);
  32485. + assertThat(uint8Image.getBitmap()).isNotSameInstanceAs(exampleBitmap);
  32486. + }
  32487. +
  32488. + @Test
  32489. + public void getBitmapWithFloatTensorImage() {
  32490. + TensorImage floatImage = new TensorImage(FLOAT32);
  32491. +
  32492. + floatImage.load(exampleBitmap);
  32493. + assertThat(floatImage.getBitmap()).isSameInstanceAs(exampleBitmap);
  32494. + }
  32495. +
  32496. + @Test
  32497. + public void getBitmapWithEmptyTensorImage() {
  32498. + TensorImage uint8Image = new TensorImage(UINT8);
  32499. +
  32500. + assertThrows(IllegalStateException.class, uint8Image::getBitmap);
  32501. + }
  32502. +
  32503. + @Test
  32504. + public void getMediaImageFailsWithBackedBitmap() {
  32505. + TensorImage tensorImage = TensorImage.fromBitmap(exampleBitmap);
  32506. +
  32507. + UnsupportedOperationException exception = assertThrows(
  32508. + UnsupportedOperationException.class, () -> tensorImage.getMediaImage());
  32509. + assertThat(exception).hasMessageThat().contains(
  32510. + "Converting from Bitmap to android.media.Image is unsupported.");
  32511. + }
  32512. +
  32513. + @Test
  32514. + public void getMediaImageFailsWithBackedTensorBuffer() {
  32515. + TensorImage tensorImage = new TensorImage(UINT8);
  32516. + tensorImage.load(exampleFloatPixels, EXAMPLE_SHAPE);
  32517. +
  32518. + UnsupportedOperationException exception = assertThrows(
  32519. + UnsupportedOperationException.class, () -> tensorImage.getMediaImage());
  32520. + assertThat(exception).hasMessageThat().contains(
  32521. + "Converting from TensorBuffer to android.media.Image is unsupported.");
  32522. + }
  32523. +
  32524. + @Test
  32525. + public void getShapeOfInternalBitmapShouldSuccess() {
  32526. + Bitmap bitmap = Bitmap.createBitmap(300, 400, Config.ARGB_8888);
  32527. + TensorImage image = TensorImage.fromBitmap(bitmap);
  32528. +
  32529. + int width = image.getWidth();
  32530. + int height = image.getHeight();
  32531. +
  32532. + assertThat(width).isEqualTo(300);
  32533. + assertThat(height).isEqualTo(400);
  32534. + }
  32535. +
  32536. + @Test
  32537. + public void getShapeOfInternalTensorBufferShouldSuccess() {
  32538. + TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 400, 300, 3}, UINT8);
  32539. + TensorImage image = new TensorImage();
  32540. + image.load(buffer);
  32541. +
  32542. + int width = image.getWidth();
  32543. + int height = image.getHeight();
  32544. +
  32545. + assertThat(width).isEqualTo(300);
  32546. + assertThat(height).isEqualTo(400);
  32547. + }
  32548. +
  32549. + @Test
  32550. + public void getShapeOfNullImageShouldThrow() {
  32551. + TensorImage image = new TensorImage();
  32552. +
  32553. + assertThrows(IllegalStateException.class, image::getHeight);
  32554. + }
  32555. +
  32556. + @Test
  32557. + public void getShapeOfACorruptedBufferShouldThrowRatherThanCrash() {
  32558. + int[] data = new int[] {1, 2, 3, 4, 5, 6};
  32559. + TensorBuffer buffer = TensorBuffer.createDynamic(UINT8);
  32560. + buffer.loadArray(data, new int[] {1, 1, 2, 3});
  32561. + TensorImage image = new TensorImage();
  32562. + image.load(buffer);
  32563. + // Reload data but with an invalid shape, which leads to `buffer` corrupted.
  32564. + int[] newData = new int[] {1, 2, 3};
  32565. + buffer.loadArray(newData, new int[] {1, 1, 1, 3});
  32566. +
  32567. + assertThrows(IllegalArgumentException.class, image::getHeight);
  32568. + }
  32569. +
  32570. + @Test
  32571. + public void getColorSpaceTypeSucceedsWithBitmapARGB_8888() {
  32572. + Bitmap rgbBitmap = createRgbBitmap();
  32573. + TensorImage tensorImage = TensorImage.fromBitmap(rgbBitmap);
  32574. +
  32575. + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
  32576. + }
  32577. +
  32578. + @Test
  32579. + public void getColorSpaceTypeSucceedsWithRgbTensorBuffer() {
  32580. + TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false);
  32581. + TensorImage tensorImage = new TensorImage();
  32582. + tensorImage.load(rgbBuffer);
  32583. +
  32584. + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
  32585. + }
  32586. +
  32587. + @Test
  32588. + public void getColorSpaceTypeSucceedsWithGrayscaleTensorBuffer() {
  32589. + TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false);
  32590. + TensorImage tensorImage = new TensorImage();
  32591. + tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE);
  32592. +
  32593. + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
  32594. + }
  32595. +
  32596. + @Test
  32597. + public void getColorSpaceTypeSucceedsWithRepeatedLoading() {
  32598. + TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false);
  32599. + TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false);
  32600. + Bitmap rgbBitmap = createRgbBitmap();
  32601. + TensorImage tensorImage = new TensorImage();
  32602. +
  32603. + tensorImage.load(rgbBuffer);
  32604. + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
  32605. + tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE);
  32606. + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
  32607. + tensorImage.load(rgbBitmap);
  32608. + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
  32609. + }
  32610. +
  32611. + @Test
  32612. + public void getColorSpaceTypeFailsWhenNoImageHasBeenLoaded() {
  32613. + TensorImage tensorImage = new TensorImage();
  32614. +
  32615. + IllegalStateException exception =
  32616. + assertThrows(IllegalStateException.class, tensorImage::getColorSpaceType);
  32617. + assertThat(exception).hasMessageThat().contains("No image has been loaded yet.");
  32618. + }
  32619. +
  32620. + /**
  32621. + * Creates an example bit map, which is a 10x10 ARGB bitmap and pixels are set by: pixel[i]
  32622. + * = {A: 0, B: i + 2, G: i + 1, G: i}, where i is the flatten index
  32623. + */
  32624. + private static Bitmap createExampleBitmap() {
  32625. + int[] colors = new int[EXAMPLE_NUM_PIXELS];
  32626. + for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) {
  32627. + colors[i] = Color.rgb(i, i + 1, i + 2);
  32628. + }
  32629. +
  32630. + return Bitmap.createBitmap(
  32631. + colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
  32632. + }
  32633. +
  32634. + private static float[] createExampleFloatPixels() {
  32635. + float[] pixels = new float[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS];
  32636. + for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) {
  32637. + pixels[j++] = (i - MEAN) / STDDEV;
  32638. + pixels[j++] = (i + 1 - MEAN) / STDDEV;
  32639. + pixels[j++] = (i + 2 - MEAN) / STDDEV;
  32640. + }
  32641. + return pixels;
  32642. + }
  32643. +
  32644. + private static int[] createExampleUint8Pixels() {
  32645. + int[] pixels = new int[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS];
  32646. + for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) {
  32647. + pixels[j++] = i;
  32648. + pixels[j++] = i + 1;
  32649. + pixels[j++] = i + 2;
  32650. + }
  32651. + return pixels;
  32652. + }
  32653. + }
  32654. +
  32655. + /** Parameterized tests for loading TensorBuffers with RGB and Grayscale images. */
  32656. + @RunWith(ParameterizedRobolectricTestRunner.class)
  32657. + public static final class LoadTensorBufferWithRgbAndGrayscale extends TensorImageTest {
  32658. + /**
  32659. + * Difference between the pair of float and uint8 values. It is used to test the data
  32660. + * conversion.
  32661. + */
  32662. + private static final float DELTA = 0.1f;
  32663. +
  32664. + /** The data type that used to create the TensorBuffer. */
  32665. + @Parameter(0)
  32666. + public DataType tensorBufferDataType;
  32667. +
  32668. + /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */
  32669. + @Parameter(1)
  32670. + public boolean isNormalized;
  32671. +
  32672. + /** The color space type of the TensorBuffer. */
  32673. + @Parameter(2)
  32674. + public ColorSpaceType colorSpaceType;
  32675. +
  32676. + /** The data type that used to create the TensorImage. */
  32677. + @Parameter(3)
  32678. + public DataType tensorImageDataType;
  32679. +
  32680. + @Parameters(name = "tensorBufferDataType={0}; isNormalized={1}; colorSpaceType={2};"
  32681. + + " tensorImageDataType={3}")
  32682. + public static Collection<Object[]>
  32683. + data() {
  32684. + return Arrays.asList(new Object[][] {
  32685. + {FLOAT32, true, ColorSpaceType.RGB, FLOAT32},
  32686. + {FLOAT32, false, ColorSpaceType.RGB, UINT8},
  32687. + {UINT8, true, ColorSpaceType.RGB, FLOAT32},
  32688. + {UINT8, false, ColorSpaceType.RGB, UINT8},
  32689. + });
  32690. + }
  32691. +
  32692. + @Test
  32693. + public void loadAndGetBitmapSucceedsWithTensorBufferAndColorSpaceType() {
  32694. + TensorBuffer tensorBuffer =
  32695. + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
  32696. + TensorImage tensorImage = new TensorImage(tensorImageDataType);
  32697. +
  32698. + tensorImage.load(tensorBuffer, colorSpaceType);
  32699. + Bitmap bitmap = tensorImage.getBitmap();
  32700. +
  32701. + Bitmap expectedBitmap = createBitmap(colorSpaceType);
  32702. + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
  32703. + }
  32704. +
  32705. + @Test
  32706. + public void loadAndGetTensorBufferSucceedsWithTensorBufferAndColorSpaceType() {
  32707. + TensorBuffer tensorBuffer =
  32708. + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
  32709. + TensorImage tensorImage = new TensorImage(tensorImageDataType);
  32710. +
  32711. + tensorImage.load(tensorBuffer, colorSpaceType);
  32712. + TensorBuffer buffer = tensorImage.getTensorBuffer();
  32713. +
  32714. + // If tensorBufferDataType is UINT8, expectedTensorBuffer should not contain delta.
  32715. + float expectedResidual = tensorBufferDataType == UINT8 ? 0.f : DELTA;
  32716. + TensorBuffer expectedTensorBuffer = createTensorBuffer(
  32717. + tensorImageDataType, isNormalized, colorSpaceType, expectedResidual);
  32718. + assertEqualTensorBuffers(buffer, expectedTensorBuffer);
  32719. + }
  32720. +
  32721. + private static TensorBuffer createTensorBuffer(DataType dataType, boolean isNormalized,
  32722. + ColorSpaceType colorSpaceType, float delta) {
  32723. + switch (colorSpaceType) {
  32724. + case RGB:
  32725. + return createRgbTensorBuffer(dataType, isNormalized, delta);
  32726. + case GRAYSCALE:
  32727. + return createGrayscaleTensorBuffer(dataType, isNormalized, delta);
  32728. + default:
  32729. + break;
  32730. + }
  32731. + throw new IllegalArgumentException(
  32732. + "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
  32733. + }
  32734. +
  32735. + private static Bitmap createBitmap(ColorSpaceType colorSpaceType) {
  32736. + switch (colorSpaceType) {
  32737. + case RGB:
  32738. + return createRgbBitmap();
  32739. + case GRAYSCALE:
  32740. + return createGrayscaleBitmap();
  32741. + default:
  32742. + break;
  32743. + }
  32744. + throw new IllegalArgumentException(
  32745. + "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
  32746. + }
  32747. + }
  32748. +
  32749. + /** Parameterized tests for loading TensorBuffers with YUV images. */
  32750. + @RunWith(ParameterizedRobolectricTestRunner.class)
  32751. + public static final class LoadTensorBufferWithYUV extends TensorImageTest {
  32752. + private static final int HEIGHT = 2;
  32753. + private static final int WIDTH = 3;
  32754. +
  32755. + @Parameter(0)
  32756. + public ColorSpaceType colorSpaceType;
  32757. +
  32758. + @Parameters(name = "colorSpaceType={0}")
  32759. + public static Collection<Object[]> data() {
  32760. + return Arrays.asList(new Object[][] {
  32761. + {ColorSpaceType.NV12},
  32762. + {ColorSpaceType.NV21},
  32763. + {ColorSpaceType.YV12},
  32764. + {ColorSpaceType.YV21},
  32765. + });
  32766. + }
  32767. +
  32768. + @Test
  32769. + public void loadTensorBufferWithColorSpaceShouldFail() {
  32770. + TensorImage tensorImage = new TensorImage();
  32771. +
  32772. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  32773. + ()
  32774. + -> tensorImage.load(
  32775. + TensorBuffer.createDynamic(DataType.FLOAT32), colorSpaceType));
  32776. + assertThat(exception).hasMessageThat().contains(
  32777. + "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
  32778. + + " `load(TensorBuffer, ImageProperties)` for other color space types.");
  32779. + }
  32780. +
  32781. + @Test
  32782. + public void loadTensorBufferAndGetBitmapShouldFail() {
  32783. + int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
  32784. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  32785. + tensorBuffer.loadArray(data, new int[] {data.length});
  32786. +
  32787. + ImageProperties imageProperties = ImageProperties.builder()
  32788. + .setHeight(HEIGHT)
  32789. + .setWidth(WIDTH)
  32790. + .setColorSpaceType(colorSpaceType)
  32791. + .build();
  32792. +
  32793. + TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
  32794. + tensorImage.load(tensorBuffer, imageProperties);
  32795. +
  32796. + UnsupportedOperationException exception = assertThrows(
  32797. + UnsupportedOperationException.class, () -> tensorImage.getBitmap());
  32798. + assertThat(exception).hasMessageThat().contains(
  32799. + "convertTensorBufferToBitmap() is unsupported for the color space type "
  32800. + + colorSpaceType.name());
  32801. + }
  32802. + }
  32803. +
  32804. + /** Parameterized tests for loading TensorBuffers with ImageProperties. */
  32805. + @RunWith(ParameterizedRobolectricTestRunner.class)
  32806. + public static final class LoadTensorBufferWithImageProperties extends TensorImageTest {
  32807. + private static final int HEIGHT = 2;
  32808. + private static final int WIDTH = 3;
  32809. + private static final int WRONG_WIDTH = 10;
  32810. +
  32811. + @Parameter(0)
  32812. + public ColorSpaceType colorSpaceType;
  32813. +
  32814. + @Parameters(name = "colorSpaceType={0}")
  32815. + public static Collection<Object[]> data() {
  32816. + return Arrays.asList(new Object[][] {
  32817. + {ColorSpaceType.RGB},
  32818. + {ColorSpaceType.GRAYSCALE},
  32819. + {ColorSpaceType.NV12},
  32820. + {ColorSpaceType.NV21},
  32821. + {ColorSpaceType.YV12},
  32822. + {ColorSpaceType.YV21},
  32823. + });
  32824. + }
  32825. +
  32826. + @Test
  32827. + public void loadAndGetTensorBufferShouldSucceedWithCorrectProperties() {
  32828. + int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
  32829. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  32830. + tensorBuffer.loadArray(data, new int[] {data.length});
  32831. +
  32832. + ImageProperties imageProperties = ImageProperties.builder()
  32833. + .setHeight(HEIGHT)
  32834. + .setWidth(WIDTH)
  32835. + .setColorSpaceType(colorSpaceType)
  32836. + .build();
  32837. +
  32838. + TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
  32839. + tensorImage.load(tensorBuffer, imageProperties);
  32840. +
  32841. + assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer);
  32842. + }
  32843. +
  32844. + @Test
  32845. + public void loadAndGetTensorBufferShouldSucceedWithLargerBuffer() {
  32846. + // Should allow buffer to be greater than the size specified by height and width.
  32847. + int moreElements = 1;
  32848. + int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH) + moreElements];
  32849. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  32850. + tensorBuffer.loadArray(data, new int[] {data.length});
  32851. +
  32852. + ImageProperties imageProperties = ImageProperties.builder()
  32853. + .setHeight(HEIGHT)
  32854. + .setWidth(WIDTH)
  32855. + .setColorSpaceType(colorSpaceType)
  32856. + .build();
  32857. +
  32858. + TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
  32859. + tensorImage.load(tensorBuffer, imageProperties);
  32860. +
  32861. + assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer);
  32862. + }
  32863. +
  32864. + @Test
  32865. + public void loadAndGetByteBufferShouldSucceedWithCorrectProperties() {
  32866. + ByteBuffer byteBuffer =
  32867. + ByteBuffer.allocate(colorSpaceType.getNumElements(HEIGHT, WIDTH));
  32868. +
  32869. + ImageProperties imageProperties = ImageProperties.builder()
  32870. + .setHeight(HEIGHT)
  32871. + .setWidth(WIDTH)
  32872. + .setColorSpaceType(colorSpaceType)
  32873. + .build();
  32874. +
  32875. + TensorImage tensorImage = new TensorImage(DataType.UINT8);
  32876. + tensorImage.load(byteBuffer, imageProperties);
  32877. +
  32878. + assertEqualByteBuffers(tensorImage.getBuffer(), byteBuffer);
  32879. + }
  32880. +
  32881. + @Test
  32882. + public void loadTensorBufferWithShouldFailWithWrongImageShape() {
  32883. + int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
  32884. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  32885. + tensorBuffer.loadArray(data, new int[] {data.length});
  32886. +
  32887. + ImageProperties imageProperties = ImageProperties.builder()
  32888. + .setHeight(HEIGHT)
  32889. + .setWidth(WRONG_WIDTH)
  32890. + .setColorSpaceType(colorSpaceType)
  32891. + .build();
  32892. +
  32893. + TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
  32894. +
  32895. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  32896. + () -> tensorImage.load(tensorBuffer, imageProperties));
  32897. + assertThat(exception).hasMessageThat().contains(String.format(
  32898. + "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
  32899. + + " expected number of elements should be at least %d.",
  32900. + data.length, colorSpaceType.name(), HEIGHT, WRONG_WIDTH,
  32901. + colorSpaceType.getNumElements(HEIGHT, WRONG_WIDTH)));
  32902. + }
  32903. +
  32904. + @Test
  32905. + public void getShapeOfInternalTensorBufferShouldSuccess() {
  32906. + int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
  32907. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  32908. + tensorBuffer.loadArray(data, new int[] {data.length});
  32909. +
  32910. + ImageProperties imageProperties = ImageProperties.builder()
  32911. + .setHeight(HEIGHT)
  32912. + .setWidth(WIDTH)
  32913. + .setColorSpaceType(colorSpaceType)
  32914. + .build();
  32915. +
  32916. + TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
  32917. + tensorImage.load(tensorBuffer, imageProperties);
  32918. +
  32919. + assertThat(tensorImage.getWidth()).isEqualTo(WIDTH);
  32920. + assertThat(tensorImage.getHeight()).isEqualTo(HEIGHT);
  32921. + }
  32922. + }
  32923. +
  32924. + /** Parameterized tests for loading TensorBuffer with invalid shapes. */
  32925. + @RunWith(ParameterizedRobolectricTestRunner.class)
  32926. + public static final class LoadTensorBufferWithInvalidShapeTest extends TensorImageTest {
  32927. + private static final String RGB_ASSERT_SHAPE_MESSAGE =
  32928. + "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
  32929. + + " representing R, G, B in order. The provided image shape is ";
  32930. + private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE =
  32931. + "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
  32932. + + " shape is ";
  32933. +
  32934. + @Parameter(0)
  32935. + public ColorSpaceType colorSpaceType;
  32936. +
  32937. + /** The shape that does not match the colorSpaceType. */
  32938. + @Parameter(1)
  32939. + public int[] invalidShape;
  32940. +
  32941. + @Parameter(2)
  32942. + public String errorMessage;
  32943. +
  32944. + @Parameters(name = "colorSpaceType={0}; invalidShape={1}")
  32945. + public static Collection<Object[]> data() {
  32946. + return Arrays.asList(new Object[][] {
  32947. + {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
  32948. + {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
  32949. + {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
  32950. + {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
  32951. + {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
  32952. + {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
  32953. + {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
  32954. + {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20},
  32955. + GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  32956. + {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3},
  32957. + GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  32958. + {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4},
  32959. + GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  32960. + {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
  32961. + });
  32962. + }
  32963. +
  32964. + @Test
  32965. + public void loadTensorBufferWithInvalidShape() {
  32966. + TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(invalidShape, UINT8);
  32967. + TensorImage tensorImage = new TensorImage();
  32968. +
  32969. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  32970. + () -> tensorImage.load(tensorBuffer, colorSpaceType));
  32971. + assertThat(exception).hasMessageThat().contains(
  32972. + errorMessage + Arrays.toString(invalidShape));
  32973. + }
  32974. + }
  32975. +
  32976. + private static void assertEqualTensorBuffers(
  32977. + TensorBuffer tensorBuffer1, TensorBuffer tensorBuffer2) {
  32978. + assertEqualByteBuffers(tensorBuffer1.getBuffer(), tensorBuffer2.getBuffer());
  32979. + assertArrayEquals(tensorBuffer1.getShape(), tensorBuffer2.getShape());
  32980. + }
  32981. +
  32982. + private static void assertEqualByteBuffers(ByteBuffer buffer1, ByteBuffer buffer2) {
  32983. + buffer1.rewind();
  32984. + buffer2.rewind();
  32985. + assertThat(buffer1.equals(buffer2)).isTrue();
  32986. + }
  32987. +
  32988. + private static void setUpImageMock(Image imageMock, int imageFormat) {
  32989. + when(imageMock.getFormat()).thenReturn(imageFormat);
  32990. }
  32991. - }
  32992. -
  32993. - private static void assertEqualTensorBuffers(
  32994. - TensorBuffer tensorBuffer1, TensorBuffer tensorBuffer2) {
  32995. - assertEqualByteBuffers(tensorBuffer1.getBuffer(), tensorBuffer2.getBuffer());
  32996. - assertArrayEquals(tensorBuffer1.getShape(), tensorBuffer2.getShape());
  32997. - }
  32998. -
  32999. - private static void assertEqualByteBuffers(ByteBuffer buffer1, ByteBuffer buffer2) {
  33000. - buffer1.rewind();
  33001. - buffer2.rewind();
  33002. - assertThat(buffer1.equals(buffer2)).isTrue();
  33003. - }
  33004. -
  33005. - private static void setUpImageMock(Image imageMock, int imageFormat) {
  33006. - when(imageMock.getFormat()).thenReturn(imageFormat);
  33007. - }
  33008. }
  33009. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java
  33010. index 7a5d0e9a9ea33..4ac2eca0b8cc6 100644
  33011. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java
  33012. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java
  33013. @@ -17,109 +17,112 @@ package org.tensorflow.lite.support.image;
  33014. import android.graphics.Bitmap;
  33015. import android.graphics.Color;
  33016. -import java.nio.ByteBuffer;
  33017. +
  33018. import org.tensorflow.lite.DataType;
  33019. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  33020. +import java.nio.ByteBuffer;
  33021. +
  33022. /** Creates test images for other test files. */
  33023. final class TestImageCreator {
  33024. - /**
  33025. - * Creates an example bitmap, which is a 10x10 ARGB bitmap and pixels are set by: <br>
  33026. - * pixel[i] = {A: 255, B: i + 2, G: i + 1, R: i}, where i is the flatten index.
  33027. - */
  33028. - static Bitmap createRgbBitmap() {
  33029. - int[] colors = new int[100];
  33030. - for (int i = 0; i < 100; i++) {
  33031. - colors[i] = Color.rgb(i, i + 1, i + 2);
  33032. + /**
  33033. + * Creates an example bitmap, which is a 10x10 ARGB bitmap and pixels are set by: <br>
  33034. + * pixel[i] = {A: 255, B: i + 2, G: i + 1, R: i}, where i is the flatten index.
  33035. + */
  33036. + static Bitmap createRgbBitmap() {
  33037. + int[] colors = new int[100];
  33038. + for (int i = 0; i < 100; i++) {
  33039. + colors[i] = Color.rgb(i, i + 1, i + 2);
  33040. + }
  33041. + return Bitmap.createBitmap(colors, 10, 10, Bitmap.Config.ARGB_8888);
  33042. }
  33043. - return Bitmap.createBitmap(colors, 10, 10, Bitmap.Config.ARGB_8888);
  33044. - }
  33045. - /**
  33046. - * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap.
  33047. - *
  33048. - * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is
  33049. - * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is[0, 1, 2, 3, ...].
  33050. - *
  33051. - * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w, 3)
  33052. - */
  33053. - static TensorBuffer createRgbTensorBuffer(DataType dataType, boolean isNormalized) {
  33054. - return createRgbTensorBuffer(dataType, isNormalized, /*delta=*/ 0.1f);
  33055. - }
  33056. -
  33057. - /**
  33058. - * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap.
  33059. - *
  33060. - * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w)
  33061. - * @param delta the delta that applied to the float values, such that the float array is [0 + +
  33062. - * delta, 1+ delta, 2+ delta, 3+ delta, ...], while the uint8 array is [0, 1, 2, 3, ...]
  33063. - */
  33064. - static TensorBuffer createRgbTensorBuffer(DataType dataType, boolean isNormalized, float delta) {
  33065. - float[] rgbValues = new float[300];
  33066. - for (int i = 0, j = 0; i < 100; i++) {
  33067. - rgbValues[j++] = i + delta;
  33068. - rgbValues[j++] = i + 1 + delta;
  33069. - rgbValues[j++] = i + 2 + delta;
  33070. + /**
  33071. + * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap.
  33072. + *
  33073. + * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is
  33074. + * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is[0, 1, 2, 3, ...].
  33075. + *
  33076. + * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w, 3)
  33077. + */
  33078. + static TensorBuffer createRgbTensorBuffer(DataType dataType, boolean isNormalized) {
  33079. + return createRgbTensorBuffer(dataType, isNormalized, /*delta=*/0.1f);
  33080. }
  33081. - int[] shape = isNormalized ? new int[] {1, 10, 10, 3} : new int[] {10, 10, 3};
  33082. - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType);
  33083. - // If dataType is UINT8, rgbValues will be converted into uint8, such as from
  33084. - // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...].
  33085. - buffer.loadArray(rgbValues, shape);
  33086. - return buffer;
  33087. - }
  33088. + /**
  33089. + * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap.
  33090. + *
  33091. + * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w)
  33092. + * @param delta the delta that applied to the float values, such that the float array is [0 + +
  33093. + * delta, 1+ delta, 2+ delta, 3+ delta, ...], while the uint8 array is [0, 1, 2, 3, ...]
  33094. + */
  33095. + static TensorBuffer createRgbTensorBuffer(
  33096. + DataType dataType, boolean isNormalized, float delta) {
  33097. + float[] rgbValues = new float[300];
  33098. + for (int i = 0, j = 0; i < 100; i++) {
  33099. + rgbValues[j++] = i + delta;
  33100. + rgbValues[j++] = i + 1 + delta;
  33101. + rgbValues[j++] = i + 2 + delta;
  33102. + }
  33103. +
  33104. + int[] shape = isNormalized ? new int[] {1, 10, 10, 3} : new int[] {10, 10, 3};
  33105. + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType);
  33106. + // If dataType is UINT8, rgbValues will be converted into uint8, such as from
  33107. + // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...].
  33108. + buffer.loadArray(rgbValues, shape);
  33109. + return buffer;
  33110. + }
  33111. - /**
  33112. - * Creates an example bitmap, which is a 10x10 ALPHA_8 bitmap and pixels are set by: <br>
  33113. - * pixel[i] = i, where i is the flatten index.
  33114. - */
  33115. - static Bitmap createGrayscaleBitmap() {
  33116. - byte[] grayValues = new byte[100];
  33117. - for (int i = 0; i < 100; i++) {
  33118. - grayValues[i] = (byte) i;
  33119. + /**
  33120. + * Creates an example bitmap, which is a 10x10 ALPHA_8 bitmap and pixels are set by: <br>
  33121. + * pixel[i] = i, where i is the flatten index.
  33122. + */
  33123. + static Bitmap createGrayscaleBitmap() {
  33124. + byte[] grayValues = new byte[100];
  33125. + for (int i = 0; i < 100; i++) {
  33126. + grayValues[i] = (byte) i;
  33127. + }
  33128. + ByteBuffer buffer = ByteBuffer.wrap(grayValues);
  33129. + Bitmap bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ALPHA_8);
  33130. + buffer.rewind();
  33131. + bitmap.copyPixelsFromBuffer(buffer);
  33132. + return bitmap;
  33133. }
  33134. - ByteBuffer buffer = ByteBuffer.wrap(grayValues);
  33135. - Bitmap bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ALPHA_8);
  33136. - buffer.rewind();
  33137. - bitmap.copyPixelsFromBuffer(buffer);
  33138. - return bitmap;
  33139. - }
  33140. - /**
  33141. - * Creates a 10*10 float or uint8 TensorBuffer representing the same image in
  33142. - * createGrayscaleBitmap.
  33143. - *
  33144. - * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is
  33145. - * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is[0, 1, 2, 3, ...].
  33146. - *
  33147. - * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w)
  33148. - */
  33149. - static TensorBuffer createGrayscaleTensorBuffer(DataType dataType, boolean isNormalized) {
  33150. - return createGrayscaleTensorBuffer(dataType, isNormalized, /*delta=*/ 0.1f);
  33151. - }
  33152. + /**
  33153. + * Creates a 10*10 float or uint8 TensorBuffer representing the same image in
  33154. + * createGrayscaleBitmap.
  33155. + *
  33156. + * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is
  33157. + * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is[0, 1, 2, 3, ...].
  33158. + *
  33159. + * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w)
  33160. + */
  33161. + static TensorBuffer createGrayscaleTensorBuffer(DataType dataType, boolean isNormalized) {
  33162. + return createGrayscaleTensorBuffer(dataType, isNormalized, /*delta=*/0.1f);
  33163. + }
  33164. - /**
  33165. - * Creates a 10*10 float or uint8 TensorBuffer representing the same image in
  33166. - * createGrayscaleBitmap.
  33167. - *
  33168. - * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w)
  33169. - * @param delta the delta that applied to the float values, such that the float array is [0 +
  33170. - * delta, 1+ delta, 2+ delta, 3+ delta, ...], while the uint8 array is [0, 1, 2, 3, ...]
  33171. - */
  33172. - static TensorBuffer createGrayscaleTensorBuffer(
  33173. - DataType dataType, boolean isNormalized, float delta) {
  33174. - float[] grayValues = new float[100];
  33175. - for (int i = 0; i < 100; i++) {
  33176. - grayValues[i] = i + delta;
  33177. + /**
  33178. + * Creates a 10*10 float or uint8 TensorBuffer representing the same image in
  33179. + * createGrayscaleBitmap.
  33180. + *
  33181. + * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w)
  33182. + * @param delta the delta that applied to the float values, such that the float array is [0 +
  33183. + * delta, 1+ delta, 2+ delta, 3+ delta, ...], while the uint8 array is [0, 1, 2, 3, ...]
  33184. + */
  33185. + static TensorBuffer createGrayscaleTensorBuffer(
  33186. + DataType dataType, boolean isNormalized, float delta) {
  33187. + float[] grayValues = new float[100];
  33188. + for (int i = 0; i < 100; i++) {
  33189. + grayValues[i] = i + delta;
  33190. + }
  33191. + int[] shape = isNormalized ? new int[] {1, 10, 10, 1} : new int[] {10, 10};
  33192. + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType);
  33193. + // If dataType is UINT8, grayValues will be converted into uint8, such as from
  33194. + // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...].
  33195. + buffer.loadArray(grayValues, shape);
  33196. + return buffer;
  33197. }
  33198. - int[] shape = isNormalized ? new int[] {1, 10, 10, 1} : new int[] {10, 10};
  33199. - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType);
  33200. - // If dataType is UINT8, grayValues will be converted into uint8, such as from
  33201. - // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...].
  33202. - buffer.loadArray(grayValues, shape);
  33203. - return buffer;
  33204. - }
  33205. - private TestImageCreator() {}
  33206. + private TestImageCreator() {}
  33207. }
  33208. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java
  33209. index a34f47d44c0ac..070e17893ad76 100644
  33210. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java
  33211. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java
  33212. @@ -19,7 +19,9 @@ import static com.google.common.truth.Truth.assertThat;
  33213. import android.graphics.Bitmap;
  33214. import android.graphics.PointF;
  33215. +
  33216. import androidx.test.ext.junit.runners.AndroidJUnit4;
  33217. +
  33218. import org.junit.Before;
  33219. import org.junit.Test;
  33220. import org.junit.runner.RunWith;
  33221. @@ -31,63 +33,62 @@ import org.tensorflow.lite.support.image.ops.ResizeOp.ResizeMethod;
  33222. /** Instrumented unit test for {@link ResizeOp}. */
  33223. @RunWith(AndroidJUnit4.class)
  33224. public class ResizeOpInstrumentedTest {
  33225. + private static final int EXAMPLE_WIDTH = 10;
  33226. + private static final int EXAMPLE_HEIGHT = 15;
  33227. - private static final int EXAMPLE_WIDTH = 10;
  33228. - private static final int EXAMPLE_HEIGHT = 15;
  33229. -
  33230. - private Bitmap exampleBitmap;
  33231. - private TensorImage input;
  33232. + private Bitmap exampleBitmap;
  33233. + private TensorImage input;
  33234. - @Before
  33235. - public void setUp() {
  33236. - exampleBitmap = createExampleBitmap();
  33237. - input = new TensorImage(DataType.UINT8);
  33238. - input.load(exampleBitmap);
  33239. - }
  33240. + @Before
  33241. + public void setUp() {
  33242. + exampleBitmap = createExampleBitmap();
  33243. + input = new TensorImage(DataType.UINT8);
  33244. + input.load(exampleBitmap);
  33245. + }
  33246. - @Test
  33247. - public void resizeShouldSuccess() {
  33248. - int targetWidth = EXAMPLE_WIDTH * 2;
  33249. - int targetHeight = EXAMPLE_HEIGHT * 2;
  33250. - ImageProcessor processor =
  33251. - new ImageProcessor.Builder()
  33252. - .add(new ResizeOp(targetHeight, targetWidth, ResizeMethod.NEAREST_NEIGHBOR))
  33253. - .build();
  33254. - TensorImage output = processor.process(input);
  33255. + @Test
  33256. + public void resizeShouldSuccess() {
  33257. + int targetWidth = EXAMPLE_WIDTH * 2;
  33258. + int targetHeight = EXAMPLE_HEIGHT * 2;
  33259. + ImageProcessor processor =
  33260. + new ImageProcessor.Builder()
  33261. + .add(new ResizeOp(targetHeight, targetWidth, ResizeMethod.NEAREST_NEIGHBOR))
  33262. + .build();
  33263. + TensorImage output = processor.process(input);
  33264. - Bitmap outputBitmap = output.getBitmap();
  33265. - assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
  33266. - assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
  33267. - for (int i = 0; i < outputBitmap.getWidth(); i++) {
  33268. - for (int j = 0; j < outputBitmap.getHeight(); j++) {
  33269. - int expected = exampleBitmap.getPixel(i / 2, j / 2);
  33270. - assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
  33271. - }
  33272. + Bitmap outputBitmap = output.getBitmap();
  33273. + assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
  33274. + assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
  33275. + for (int i = 0; i < outputBitmap.getWidth(); i++) {
  33276. + for (int j = 0; j < outputBitmap.getHeight(); j++) {
  33277. + int expected = exampleBitmap.getPixel(i / 2, j / 2);
  33278. + assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
  33279. + }
  33280. + }
  33281. }
  33282. - }
  33283. - @Test
  33284. - public void inverseTransformPointShouldSuccess() {
  33285. - ResizeOp op = new ResizeOp(200, 300, ResizeMethod.NEAREST_NEIGHBOR);
  33286. - PointF transformed = new PointF(32.0f, 42.0f);
  33287. - // The original image size is 900x400 assumed
  33288. - PointF original = op.inverseTransform(transformed, 400, 900);
  33289. - assertThat(original.x).isEqualTo(96);
  33290. - assertThat(original.y).isEqualTo(84);
  33291. - PointF outside = op.inverseTransform(new PointF(500, 1000), 400, 900);
  33292. - assertThat(outside.x).isEqualTo(1500);
  33293. - assertThat(outside.y).isEqualTo(2000);
  33294. - }
  33295. + @Test
  33296. + public void inverseTransformPointShouldSuccess() {
  33297. + ResizeOp op = new ResizeOp(200, 300, ResizeMethod.NEAREST_NEIGHBOR);
  33298. + PointF transformed = new PointF(32.0f, 42.0f);
  33299. + // The original image size is 900x400 assumed
  33300. + PointF original = op.inverseTransform(transformed, 400, 900);
  33301. + assertThat(original.x).isEqualTo(96);
  33302. + assertThat(original.y).isEqualTo(84);
  33303. + PointF outside = op.inverseTransform(new PointF(500, 1000), 400, 900);
  33304. + assertThat(outside.x).isEqualTo(1500);
  33305. + assertThat(outside.y).isEqualTo(2000);
  33306. + }
  33307. - /**
  33308. - * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] = {A:
  33309. - * 255, B: i + 2, G: i + 1, G: i}, where i is the flatten index
  33310. - */
  33311. - private static Bitmap createExampleBitmap() {
  33312. - int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
  33313. - for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
  33314. - colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
  33315. + /**
  33316. + * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] =
  33317. + * {A: 255, B: i + 2, G: i + 1, G: i}, where i is the flatten index
  33318. + */
  33319. + private static Bitmap createExampleBitmap() {
  33320. + int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
  33321. + for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
  33322. + colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
  33323. + }
  33324. + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
  33325. }
  33326. - return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
  33327. - }
  33328. }
  33329. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java
  33330. index 5c483780b30f4..85c777904f2ec 100644
  33331. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java
  33332. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java
  33333. @@ -19,7 +19,9 @@ import static com.google.common.truth.Truth.assertThat;
  33334. import android.graphics.Bitmap;
  33335. import android.graphics.PointF;
  33336. +
  33337. import androidx.test.ext.junit.runners.AndroidJUnit4;
  33338. +
  33339. import org.junit.Before;
  33340. import org.junit.Test;
  33341. import org.junit.runner.RunWith;
  33342. @@ -30,131 +32,128 @@ import org.tensorflow.lite.support.image.TensorImage;
  33343. /** Instrumented unit test for {@link ResizeWithCropOrPadOp}. */
  33344. @RunWith(AndroidJUnit4.class)
  33345. public class ResizeWithCropOrPadOpInstrumentedTest {
  33346. + private Bitmap exampleBitmap;
  33347. + private TensorImage input;
  33348. - private Bitmap exampleBitmap;
  33349. - private TensorImage input;
  33350. -
  33351. - private static final int EXAMPLE_WIDTH = 10;
  33352. - private static final int EXAMPLE_HEIGHT = 15;
  33353. -
  33354. - @Before
  33355. - public void setUp() {
  33356. - exampleBitmap = createExampleBitmap();
  33357. - input = new TensorImage(DataType.UINT8);
  33358. - input.load(exampleBitmap);
  33359. - }
  33360. -
  33361. - @Test
  33362. - public void testResizeWithCrop() {
  33363. - int targetWidth = 6;
  33364. - int targetHeight = 5;
  33365. - ImageProcessor processor =
  33366. - new ImageProcessor.Builder()
  33367. - .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth))
  33368. - .build();
  33369. - TensorImage output = processor.process(input);
  33370. -
  33371. - Bitmap outputBitmap = output.getBitmap();
  33372. - assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
  33373. - assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
  33374. - for (int i = 0; i < outputBitmap.getWidth(); i++) {
  33375. - for (int j = 0; j < outputBitmap.getHeight(); j++) {
  33376. - int expected =
  33377. - exampleBitmap.getPixel(
  33378. - i + (EXAMPLE_WIDTH - targetWidth) / 2, j + (EXAMPLE_HEIGHT - targetHeight) / 2);
  33379. - assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
  33380. - }
  33381. + private static final int EXAMPLE_WIDTH = 10;
  33382. + private static final int EXAMPLE_HEIGHT = 15;
  33383. +
  33384. + @Before
  33385. + public void setUp() {
  33386. + exampleBitmap = createExampleBitmap();
  33387. + input = new TensorImage(DataType.UINT8);
  33388. + input.load(exampleBitmap);
  33389. }
  33390. - }
  33391. -
  33392. - @Test
  33393. - public void testResizeWithPad() {
  33394. - int targetWidth = 15;
  33395. - int targetHeight = 20;
  33396. - ImageProcessor processor =
  33397. - new ImageProcessor.Builder()
  33398. - .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth))
  33399. - .build();
  33400. - TensorImage output = processor.process(input);
  33401. - // Pad 2 rows / columns on top / left, and 3 rows / columns on bottom / right
  33402. -
  33403. - Bitmap outputBitmap = output.getBitmap();
  33404. - assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
  33405. - assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
  33406. - int leftPad = (targetWidth - EXAMPLE_WIDTH) / 2;
  33407. - int topPad = (targetHeight - EXAMPLE_HEIGHT) / 2;
  33408. - for (int i = 0; i < outputBitmap.getWidth(); i++) {
  33409. - for (int j = 0; j < outputBitmap.getHeight(); j++) {
  33410. - int expected = 0; // ZERO padding
  33411. - if (i >= leftPad
  33412. - && i < leftPad + EXAMPLE_WIDTH
  33413. - && j >= topPad
  33414. - && j < topPad + EXAMPLE_HEIGHT) {
  33415. - expected = exampleBitmap.getPixel(i - leftPad, j - topPad);
  33416. +
  33417. + @Test
  33418. + public void testResizeWithCrop() {
  33419. + int targetWidth = 6;
  33420. + int targetHeight = 5;
  33421. + ImageProcessor processor =
  33422. + new ImageProcessor.Builder()
  33423. + .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth))
  33424. + .build();
  33425. + TensorImage output = processor.process(input);
  33426. +
  33427. + Bitmap outputBitmap = output.getBitmap();
  33428. + assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
  33429. + assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
  33430. + for (int i = 0; i < outputBitmap.getWidth(); i++) {
  33431. + for (int j = 0; j < outputBitmap.getHeight(); j++) {
  33432. + int expected = exampleBitmap.getPixel(i + (EXAMPLE_WIDTH - targetWidth) / 2,
  33433. + j + (EXAMPLE_HEIGHT - targetHeight) / 2);
  33434. + assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
  33435. + }
  33436. + }
  33437. + }
  33438. +
  33439. + @Test
  33440. + public void testResizeWithPad() {
  33441. + int targetWidth = 15;
  33442. + int targetHeight = 20;
  33443. + ImageProcessor processor =
  33444. + new ImageProcessor.Builder()
  33445. + .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth))
  33446. + .build();
  33447. + TensorImage output = processor.process(input);
  33448. + // Pad 2 rows / columns on top / left, and 3 rows / columns on bottom / right
  33449. +
  33450. + Bitmap outputBitmap = output.getBitmap();
  33451. + assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
  33452. + assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
  33453. + int leftPad = (targetWidth - EXAMPLE_WIDTH) / 2;
  33454. + int topPad = (targetHeight - EXAMPLE_HEIGHT) / 2;
  33455. + for (int i = 0; i < outputBitmap.getWidth(); i++) {
  33456. + for (int j = 0; j < outputBitmap.getHeight(); j++) {
  33457. + int expected = 0; // ZERO padding
  33458. + if (i >= leftPad && i < leftPad + EXAMPLE_WIDTH && j >= topPad
  33459. + && j < topPad + EXAMPLE_HEIGHT) {
  33460. + expected = exampleBitmap.getPixel(i - leftPad, j - topPad);
  33461. + }
  33462. + assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
  33463. + }
  33464. }
  33465. - assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
  33466. - }
  33467. }
  33468. - }
  33469. -
  33470. - @Test
  33471. - public void testResizeWithCropAndPad() {
  33472. - int targetSize = 12;
  33473. - // Pad 1 column on left & right, crop 1 row on top and 2 rows on bottom
  33474. - ImageProcessor processor =
  33475. - new ImageProcessor.Builder().add(new ResizeWithCropOrPadOp(targetSize, targetSize)).build();
  33476. - TensorImage output = processor.process(input);
  33477. -
  33478. - Bitmap outputBitmap = output.getBitmap();
  33479. - assertThat(outputBitmap.getWidth()).isEqualTo(targetSize);
  33480. - assertThat(outputBitmap.getHeight()).isEqualTo(targetSize);
  33481. -
  33482. - int leftPad = (targetSize - EXAMPLE_WIDTH) / 2;
  33483. - int topCrop = (EXAMPLE_HEIGHT - targetSize) / 2;
  33484. - for (int i = 0; i < outputBitmap.getWidth(); i++) {
  33485. - for (int j = 0; j < outputBitmap.getHeight(); j++) {
  33486. - int expected = 0;
  33487. - if (i >= leftPad && i < leftPad + EXAMPLE_WIDTH) {
  33488. - expected = exampleBitmap.getPixel(i - leftPad, j + topCrop);
  33489. +
  33490. + @Test
  33491. + public void testResizeWithCropAndPad() {
  33492. + int targetSize = 12;
  33493. + // Pad 1 column on left & right, crop 1 row on top and 2 rows on bottom
  33494. + ImageProcessor processor = new ImageProcessor.Builder()
  33495. + .add(new ResizeWithCropOrPadOp(targetSize, targetSize))
  33496. + .build();
  33497. + TensorImage output = processor.process(input);
  33498. +
  33499. + Bitmap outputBitmap = output.getBitmap();
  33500. + assertThat(outputBitmap.getWidth()).isEqualTo(targetSize);
  33501. + assertThat(outputBitmap.getHeight()).isEqualTo(targetSize);
  33502. +
  33503. + int leftPad = (targetSize - EXAMPLE_WIDTH) / 2;
  33504. + int topCrop = (EXAMPLE_HEIGHT - targetSize) / 2;
  33505. + for (int i = 0; i < outputBitmap.getWidth(); i++) {
  33506. + for (int j = 0; j < outputBitmap.getHeight(); j++) {
  33507. + int expected = 0;
  33508. + if (i >= leftPad && i < leftPad + EXAMPLE_WIDTH) {
  33509. + expected = exampleBitmap.getPixel(i - leftPad, j + topCrop);
  33510. + }
  33511. + assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
  33512. + }
  33513. }
  33514. - assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
  33515. - }
  33516. }
  33517. - }
  33518. -
  33519. - @Test
  33520. - public void inverseTransformCorrectlyWhenCropped() {
  33521. - ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300);
  33522. - // The point (100, 50) is transformed from 600x500 image
  33523. - PointF original = op.inverseTransform(new PointF(100, 50), 500, 600);
  33524. - assertThat(original.x).isEqualTo(250);
  33525. - assertThat(original.y).isEqualTo(150);
  33526. - PointF cropped = op.inverseTransform(new PointF(-10, -10), 500, 600);
  33527. - assertThat(cropped.x).isEqualTo(140);
  33528. - assertThat(cropped.y).isEqualTo(90);
  33529. - }
  33530. -
  33531. - @Test
  33532. - public void inverseTransformCorrectlyWhenPadded() {
  33533. - ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300);
  33534. - // The point (100, 50) is transformed from 100x200 image
  33535. - PointF original = op.inverseTransform(new PointF(100, 50), 200, 100);
  33536. - assertThat(original.x).isEqualTo(0);
  33537. - assertThat(original.y).isEqualTo(0);
  33538. - PointF outside = op.inverseTransform(new PointF(50, 10), 200, 100);
  33539. - assertThat(outside.x).isEqualTo(-50);
  33540. - assertThat(outside.y).isEqualTo(-40);
  33541. - }
  33542. -
  33543. - /**
  33544. - * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] = {A:
  33545. - * 255, R: i + 2, G: i + 1, B: i}, where i is the flatten index
  33546. - */
  33547. - private static Bitmap createExampleBitmap() {
  33548. - int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
  33549. - for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
  33550. - colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
  33551. +
  33552. + @Test
  33553. + public void inverseTransformCorrectlyWhenCropped() {
  33554. + ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300);
  33555. + // The point (100, 50) is transformed from 600x500 image
  33556. + PointF original = op.inverseTransform(new PointF(100, 50), 500, 600);
  33557. + assertThat(original.x).isEqualTo(250);
  33558. + assertThat(original.y).isEqualTo(150);
  33559. + PointF cropped = op.inverseTransform(new PointF(-10, -10), 500, 600);
  33560. + assertThat(cropped.x).isEqualTo(140);
  33561. + assertThat(cropped.y).isEqualTo(90);
  33562. + }
  33563. +
  33564. + @Test
  33565. + public void inverseTransformCorrectlyWhenPadded() {
  33566. + ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300);
  33567. + // The point (100, 50) is transformed from 100x200 image
  33568. + PointF original = op.inverseTransform(new PointF(100, 50), 200, 100);
  33569. + assertThat(original.x).isEqualTo(0);
  33570. + assertThat(original.y).isEqualTo(0);
  33571. + PointF outside = op.inverseTransform(new PointF(50, 10), 200, 100);
  33572. + assertThat(outside.x).isEqualTo(-50);
  33573. + assertThat(outside.y).isEqualTo(-40);
  33574. + }
  33575. +
  33576. + /**
  33577. + * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] =
  33578. + * {A: 255, R: i + 2, G: i + 1, B: i}, where i is the flatten index
  33579. + */
  33580. + private static Bitmap createExampleBitmap() {
  33581. + int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
  33582. + for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
  33583. + colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
  33584. + }
  33585. + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
  33586. }
  33587. - return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
  33588. - }
  33589. }
  33590. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java
  33591. index eb54788764f1e..d00fe0e44422e 100644
  33592. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java
  33593. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java
  33594. @@ -19,7 +19,9 @@ import static com.google.common.truth.Truth.assertThat;
  33595. import android.graphics.Bitmap;
  33596. import android.graphics.PointF;
  33597. +
  33598. import androidx.test.ext.junit.runners.AndroidJUnit4;
  33599. +
  33600. import org.junit.Before;
  33601. import org.junit.Test;
  33602. import org.junit.runner.RunWith;
  33603. @@ -30,68 +32,68 @@ import org.tensorflow.lite.support.image.TensorImage;
  33604. /** Instrumented unit test for {@link Rot90Op}. */
  33605. @RunWith(AndroidJUnit4.class)
  33606. public class Rot90OpInstrumentedTest {
  33607. + private Bitmap exampleBitmap;
  33608. + private TensorImage input;
  33609. +
  33610. + private static final int EXAMPLE_WIDTH = 10;
  33611. + private static final int EXAMPLE_HEIGHT = 15;
  33612. - private Bitmap exampleBitmap;
  33613. - private TensorImage input;
  33614. -
  33615. - private static final int EXAMPLE_WIDTH = 10;
  33616. - private static final int EXAMPLE_HEIGHT = 15;
  33617. -
  33618. - @Before
  33619. - public void setUp() {
  33620. - exampleBitmap = createExampleBitmap();
  33621. - input = new TensorImage(DataType.UINT8);
  33622. - input.load(exampleBitmap);
  33623. - }
  33624. -
  33625. - @Test
  33626. - public void testRot90() {
  33627. - ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op()).build();
  33628. - TensorImage output = processor.process(input);
  33629. -
  33630. - Bitmap outputBitmap = output.getBitmap();
  33631. - assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_HEIGHT);
  33632. - assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_WIDTH);
  33633. - for (int i = 0; i < exampleBitmap.getWidth(); i++) {
  33634. - for (int j = 0; j < exampleBitmap.getHeight(); j++) {
  33635. - assertThat(exampleBitmap.getPixel(i, j))
  33636. - .isEqualTo(outputBitmap.getPixel(j, EXAMPLE_WIDTH - 1 - i));
  33637. - }
  33638. + @Before
  33639. + public void setUp() {
  33640. + exampleBitmap = createExampleBitmap();
  33641. + input = new TensorImage(DataType.UINT8);
  33642. + input.load(exampleBitmap);
  33643. }
  33644. - }
  33645. -
  33646. - @Test
  33647. - public void testRot90Twice() {
  33648. - ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op(2)).build();
  33649. - TensorImage output = processor.process(input);
  33650. -
  33651. - Bitmap outputBitmap = output.getBitmap();
  33652. - assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_WIDTH);
  33653. - assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
  33654. - for (int i = 0; i < exampleBitmap.getWidth(); i++) {
  33655. - for (int j = 0; j < exampleBitmap.getHeight(); j++) {
  33656. - assertThat(exampleBitmap.getPixel(i, j))
  33657. - .isEqualTo(outputBitmap.getPixel(EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j));
  33658. - }
  33659. +
  33660. + @Test
  33661. + public void testRot90() {
  33662. + ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op()).build();
  33663. + TensorImage output = processor.process(input);
  33664. +
  33665. + Bitmap outputBitmap = output.getBitmap();
  33666. + assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_HEIGHT);
  33667. + assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_WIDTH);
  33668. + for (int i = 0; i < exampleBitmap.getWidth(); i++) {
  33669. + for (int j = 0; j < exampleBitmap.getHeight(); j++) {
  33670. + assertThat(exampleBitmap.getPixel(i, j))
  33671. + .isEqualTo(outputBitmap.getPixel(j, EXAMPLE_WIDTH - 1 - i));
  33672. + }
  33673. + }
  33674. }
  33675. - }
  33676. -
  33677. - @Test
  33678. - public void inverseTransformCorrectlyWhenRotated() {
  33679. - Rot90Op op = new Rot90Op(3);
  33680. - PointF original = op.inverseTransform(new PointF(20, 10), 200, 100);
  33681. - assertThat(original.x).isEqualTo(10);
  33682. - assertThat(original.y).isEqualTo(180);
  33683. - PointF outside = op.inverseTransform(new PointF(-10, 110), 200, 100);
  33684. - assertThat(outside.x).isEqualTo(110);
  33685. - assertThat(outside.y).isEqualTo(210);
  33686. - }
  33687. -
  33688. - private static Bitmap createExampleBitmap() {
  33689. - int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
  33690. - for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
  33691. - colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
  33692. +
  33693. + @Test
  33694. + public void testRot90Twice() {
  33695. + ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op(2)).build();
  33696. + TensorImage output = processor.process(input);
  33697. +
  33698. + Bitmap outputBitmap = output.getBitmap();
  33699. + assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_WIDTH);
  33700. + assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
  33701. + for (int i = 0; i < exampleBitmap.getWidth(); i++) {
  33702. + for (int j = 0; j < exampleBitmap.getHeight(); j++) {
  33703. + assertThat(exampleBitmap.getPixel(i, j))
  33704. + .isEqualTo(outputBitmap.getPixel(
  33705. + EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j));
  33706. + }
  33707. + }
  33708. + }
  33709. +
  33710. + @Test
  33711. + public void inverseTransformCorrectlyWhenRotated() {
  33712. + Rot90Op op = new Rot90Op(3);
  33713. + PointF original = op.inverseTransform(new PointF(20, 10), 200, 100);
  33714. + assertThat(original.x).isEqualTo(10);
  33715. + assertThat(original.y).isEqualTo(180);
  33716. + PointF outside = op.inverseTransform(new PointF(-10, 110), 200, 100);
  33717. + assertThat(outside.x).isEqualTo(110);
  33718. + assertThat(outside.y).isEqualTo(210);
  33719. + }
  33720. +
  33721. + private static Bitmap createExampleBitmap() {
  33722. + int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
  33723. + for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
  33724. + colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
  33725. + }
  33726. + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
  33727. }
  33728. - return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
  33729. - }
  33730. }
  33731. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java
  33732. index 46713fd486fa7..f024f68911d27 100644
  33733. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java
  33734. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java
  33735. @@ -16,6 +16,7 @@ limitations under the License.
  33736. package org.tensorflow.lite.support.image.ops;
  33737. import static com.google.common.truth.Truth.assertThat;
  33738. +
  33739. import static org.junit.Assert.assertThrows;
  33740. import static org.mockito.Mockito.doReturn;
  33741. import static org.tensorflow.lite.DataType.UINT8;
  33742. @@ -24,7 +25,9 @@ import android.graphics.Bitmap;
  33743. import android.graphics.Color;
  33744. import android.graphics.ImageFormat;
  33745. import android.media.Image;
  33746. +
  33747. import androidx.test.ext.junit.runners.AndroidJUnit4;
  33748. +
  33749. import org.junit.Before;
  33750. import org.junit.Rule;
  33751. import org.junit.Test;
  33752. @@ -40,54 +43,55 @@ import org.tensorflow.lite.support.image.TensorImage;
  33753. /** Instrumented unit test for {@link TransformToGrayscaleOp}. */
  33754. @RunWith(AndroidJUnit4.class)
  33755. public class TransformToGrayScaleOpInstrumentedTest {
  33756. -
  33757. - @Rule public final MockitoRule mockito = MockitoJUnit.rule();
  33758. -
  33759. - private TensorImage input;
  33760. -
  33761. - private static final int EXAMPLE_WIDTH = 2;
  33762. - private static final int EXAMPLE_HEIGHT = 3;
  33763. - @Mock Image imageMock;
  33764. -
  33765. - @Before
  33766. - public void setUp() {
  33767. - Bitmap exampleBitmap = createExampleBitmap();
  33768. - input = new TensorImage(DataType.UINT8);
  33769. - input.load(exampleBitmap);
  33770. - }
  33771. -
  33772. - @Test
  33773. - public void apply_onRgb_succeeds() {
  33774. - ImageProcessor processor =
  33775. - new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build();
  33776. -
  33777. - TensorImage output = processor.process(input);
  33778. - int[] pixels = output.getTensorBuffer().getIntArray();
  33779. -
  33780. - assertThat(output.getWidth()).isEqualTo(EXAMPLE_WIDTH);
  33781. - assertThat(output.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
  33782. - assertThat(output.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
  33783. - assertThat(pixels).isEqualTo(new int[] {0, 255, 76, 29, 150, 179});
  33784. - }
  33785. -
  33786. - @Test
  33787. - public void apply_onYuv_throws() {
  33788. - setUpImageMock(imageMock, ImageFormat.YUV_420_888);
  33789. - TensorImage tensorImage = new TensorImage(UINT8);
  33790. - tensorImage.load(imageMock);
  33791. - ImageProcessor processor =
  33792. - new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build();
  33793. -
  33794. - assertThrows(IllegalArgumentException.class, () -> processor.process(tensorImage));
  33795. - }
  33796. -
  33797. - private static Bitmap createExampleBitmap() {
  33798. - int[] colors =
  33799. - new int[] {Color.BLACK, Color.WHITE, Color.RED, Color.BLUE, Color.GREEN, Color.CYAN};
  33800. - return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
  33801. - }
  33802. -
  33803. - private static void setUpImageMock(Image imageMock, int imageFormat) {
  33804. - doReturn(imageFormat).when(imageMock).getFormat();
  33805. - }
  33806. + @Rule
  33807. + public final MockitoRule mockito = MockitoJUnit.rule();
  33808. +
  33809. + private TensorImage input;
  33810. +
  33811. + private static final int EXAMPLE_WIDTH = 2;
  33812. + private static final int EXAMPLE_HEIGHT = 3;
  33813. + @Mock
  33814. + Image imageMock;
  33815. +
  33816. + @Before
  33817. + public void setUp() {
  33818. + Bitmap exampleBitmap = createExampleBitmap();
  33819. + input = new TensorImage(DataType.UINT8);
  33820. + input.load(exampleBitmap);
  33821. + }
  33822. +
  33823. + @Test
  33824. + public void apply_onRgb_succeeds() {
  33825. + ImageProcessor processor =
  33826. + new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build();
  33827. +
  33828. + TensorImage output = processor.process(input);
  33829. + int[] pixels = output.getTensorBuffer().getIntArray();
  33830. +
  33831. + assertThat(output.getWidth()).isEqualTo(EXAMPLE_WIDTH);
  33832. + assertThat(output.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
  33833. + assertThat(output.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
  33834. + assertThat(pixels).isEqualTo(new int[] {0, 255, 76, 29, 150, 179});
  33835. + }
  33836. +
  33837. + @Test
  33838. + public void apply_onYuv_throws() {
  33839. + setUpImageMock(imageMock, ImageFormat.YUV_420_888);
  33840. + TensorImage tensorImage = new TensorImage(UINT8);
  33841. + tensorImage.load(imageMock);
  33842. + ImageProcessor processor =
  33843. + new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build();
  33844. +
  33845. + assertThrows(IllegalArgumentException.class, () -> processor.process(tensorImage));
  33846. + }
  33847. +
  33848. + private static Bitmap createExampleBitmap() {
  33849. + int[] colors = new int[] {
  33850. + Color.BLACK, Color.WHITE, Color.RED, Color.BLUE, Color.GREEN, Color.CYAN};
  33851. + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
  33852. + }
  33853. +
  33854. + private static void setUpImageMock(Image imageMock, int imageFormat) {
  33855. + doReturn(imageFormat).when(imageMock).getFormat();
  33856. + }
  33857. }
  33858. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java
  33859. index 28620dd941e9c..98d1f92f56c6d 100644
  33860. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java
  33861. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java
  33862. @@ -24,114 +24,98 @@ import org.robolectric.RobolectricTestRunner;
  33863. /** Tests of {@link org.tensorflow.lite.support.label.Category}. */
  33864. @RunWith(RobolectricTestRunner.class)
  33865. public final class CategoryTest {
  33866. - private static final String APPLE_LABEL = "apple";
  33867. - private static final String DEFAULT_DISPLAY_NAME = "";
  33868. - private static final String APPLE_DISPLAY_NAME = "manzana"; // "apple" in Spanish.
  33869. - private static final float APPLE_SCORE = 0.5f;
  33870. - private static final int APPLE_INDEX = 10;
  33871. -
  33872. - @Test
  33873. - public void createShouldSucceed() {
  33874. - Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
  33875. -
  33876. - assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
  33877. - assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME);
  33878. - assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
  33879. - }
  33880. -
  33881. - @Test
  33882. - public void createWithIndexShouldSucceed() {
  33883. - Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
  33884. -
  33885. - assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
  33886. - assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME);
  33887. - assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
  33888. - assertThat(category.getIndex()).isEqualTo(APPLE_INDEX);
  33889. - }
  33890. -
  33891. - @Test
  33892. - public void constructorShouldSucceed() {
  33893. - Category category = new Category(APPLE_LABEL, APPLE_SCORE);
  33894. -
  33895. - assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
  33896. - // Using the constructor, displayName will be default to an empty string.
  33897. - assertThat(category.getDisplayName()).isEqualTo(DEFAULT_DISPLAY_NAME);
  33898. - assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
  33899. - }
  33900. -
  33901. - @Test
  33902. - public void toStringWithCreateShouldProvideReadableResult() {
  33903. - Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
  33904. - String categoryString = category.toString();
  33905. -
  33906. - assertThat(categoryString)
  33907. - .isEqualTo(
  33908. - "<Category \""
  33909. - + APPLE_LABEL
  33910. - + "\" (displayName="
  33911. - + APPLE_DISPLAY_NAME
  33912. - + " score="
  33913. - + APPLE_SCORE
  33914. - + " index=-1"
  33915. - + ")>");
  33916. - }
  33917. -
  33918. - @Test
  33919. - public void toStringWithCreateIndexShouldProvideReadableResult() {
  33920. - Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
  33921. - String categoryString = category.toString();
  33922. -
  33923. - assertThat(categoryString)
  33924. - .isEqualTo(
  33925. - "<Category \""
  33926. - + APPLE_LABEL
  33927. - + "\" (displayName="
  33928. - + APPLE_DISPLAY_NAME
  33929. - + " score="
  33930. - + APPLE_SCORE
  33931. - + " index="
  33932. - + APPLE_INDEX
  33933. - + ")>");
  33934. - }
  33935. -
  33936. - @Test
  33937. - public void toStringWithConstuctorShouldProvideReadableResult() {
  33938. - Category category = new Category(APPLE_LABEL, APPLE_SCORE);
  33939. - String categoryString = category.toString();
  33940. -
  33941. - assertThat(categoryString)
  33942. - .isEqualTo(
  33943. - "<Category \""
  33944. - + APPLE_LABEL
  33945. - + "\" (displayName="
  33946. - + DEFAULT_DISPLAY_NAME
  33947. - + " score="
  33948. - + APPLE_SCORE
  33949. - + " index=-1"
  33950. - + ")>");
  33951. - }
  33952. -
  33953. - @Test
  33954. - public void equalsShouldSucceedWithCreate() {
  33955. - Category categoryA = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
  33956. - Category categoryB = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
  33957. -
  33958. - assertThat(categoryA).isEqualTo(categoryB);
  33959. - }
  33960. -
  33961. - @Test
  33962. - public void equalsShouldSucceedWithCreateIndex() {
  33963. - Category categoryA = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
  33964. - Category categoryB = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
  33965. -
  33966. - assertThat(categoryA).isEqualTo(categoryB);
  33967. - }
  33968. -
  33969. - @Test
  33970. - public void equalsShouldSucceedWithConstructor() {
  33971. - Category categoryA = new Category(APPLE_LABEL, APPLE_SCORE);
  33972. - Category categoryB = new Category(APPLE_LABEL, APPLE_SCORE);
  33973. -
  33974. - assertThat(categoryA).isEqualTo(categoryB);
  33975. - }
  33976. + private static final String APPLE_LABEL = "apple";
  33977. + private static final String DEFAULT_DISPLAY_NAME = "";
  33978. + private static final String APPLE_DISPLAY_NAME = "manzana"; // "apple" in Spanish.
  33979. + private static final float APPLE_SCORE = 0.5f;
  33980. + private static final int APPLE_INDEX = 10;
  33981. +
  33982. + @Test
  33983. + public void createShouldSucceed() {
  33984. + Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
  33985. +
  33986. + assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
  33987. + assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME);
  33988. + assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
  33989. + }
  33990. +
  33991. + @Test
  33992. + public void createWithIndexShouldSucceed() {
  33993. + Category category =
  33994. + Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
  33995. +
  33996. + assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
  33997. + assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME);
  33998. + assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
  33999. + assertThat(category.getIndex()).isEqualTo(APPLE_INDEX);
  34000. + }
  34001. +
  34002. + @Test
  34003. + public void constructorShouldSucceed() {
  34004. + Category category = new Category(APPLE_LABEL, APPLE_SCORE);
  34005. +
  34006. + assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
  34007. + // Using the constructor, displayName will be default to an empty string.
  34008. + assertThat(category.getDisplayName()).isEqualTo(DEFAULT_DISPLAY_NAME);
  34009. + assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
  34010. + }
  34011. +
  34012. + @Test
  34013. + public void toStringWithCreateShouldProvideReadableResult() {
  34014. + Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
  34015. + String categoryString = category.toString();
  34016. +
  34017. + assertThat(categoryString)
  34018. + .isEqualTo("<Category \"" + APPLE_LABEL + "\" (displayName=" + APPLE_DISPLAY_NAME
  34019. + + " score=" + APPLE_SCORE + " index=-1"
  34020. + + ")>");
  34021. + }
  34022. +
  34023. + @Test
  34024. + public void toStringWithCreateIndexShouldProvideReadableResult() {
  34025. + Category category =
  34026. + Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
  34027. + String categoryString = category.toString();
  34028. +
  34029. + assertThat(categoryString)
  34030. + .isEqualTo("<Category \"" + APPLE_LABEL + "\" (displayName=" + APPLE_DISPLAY_NAME
  34031. + + " score=" + APPLE_SCORE + " index=" + APPLE_INDEX + ")>");
  34032. + }
  34033. +
  34034. + @Test
  34035. + public void toStringWithConstuctorShouldProvideReadableResult() {
  34036. + Category category = new Category(APPLE_LABEL, APPLE_SCORE);
  34037. + String categoryString = category.toString();
  34038. +
  34039. + assertThat(categoryString)
  34040. + .isEqualTo("<Category \"" + APPLE_LABEL + "\" (displayName=" + DEFAULT_DISPLAY_NAME
  34041. + + " score=" + APPLE_SCORE + " index=-1"
  34042. + + ")>");
  34043. + }
  34044. +
  34045. + @Test
  34046. + public void equalsShouldSucceedWithCreate() {
  34047. + Category categoryA = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
  34048. + Category categoryB = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
  34049. +
  34050. + assertThat(categoryA).isEqualTo(categoryB);
  34051. + }
  34052. +
  34053. + @Test
  34054. + public void equalsShouldSucceedWithCreateIndex() {
  34055. + Category categoryA =
  34056. + Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
  34057. + Category categoryB =
  34058. + Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
  34059. +
  34060. + assertThat(categoryA).isEqualTo(categoryB);
  34061. + }
  34062. +
  34063. + @Test
  34064. + public void equalsShouldSucceedWithConstructor() {
  34065. + Category categoryA = new Category(APPLE_LABEL, APPLE_SCORE);
  34066. + Category categoryB = new Category(APPLE_LABEL, APPLE_SCORE);
  34067. +
  34068. + assertThat(categoryA).isEqualTo(categoryB);
  34069. + }
  34070. }
  34071. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java
  34072. index caa468bb0a9ec..91c81c4932b81 100644
  34073. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java
  34074. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java
  34075. @@ -17,35 +17,38 @@ package org.tensorflow.lite.support.label;
  34076. import static com.google.common.truth.Truth.assertThat;
  34077. -import java.util.Arrays;
  34078. -import java.util.List;
  34079. import org.junit.Test;
  34080. import org.junit.runner.RunWith;
  34081. import org.robolectric.RobolectricTestRunner;
  34082. import org.tensorflow.lite.DataType;
  34083. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  34084. +import java.util.Arrays;
  34085. +import java.util.List;
  34086. +
  34087. /** Tests of {@link org.tensorflow.lite.support.label.LabelUtil}. */
  34088. @RunWith(RobolectricTestRunner.class)
  34089. public class LabelUtilTest {
  34090. -
  34091. - @Test
  34092. - public void mapIndexToStringsWithInvalidValues() {
  34093. - String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"};
  34094. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
  34095. - tensorBuffer.loadArray(new int[] {0, 1, 2, 3, 2, 5}, new int[] {1, 6});
  34096. - List<String> categories = LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1);
  34097. - assertThat(categories.toArray())
  34098. - .isEqualTo(new String[] {"apple", "banana", "cherry", "date", "cherry", ""});
  34099. - }
  34100. -
  34101. - @Test
  34102. - public void mapFloatIndexShouldCast() {
  34103. - String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"};
  34104. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  34105. - tensorBuffer.loadArray(new float[] {-1.1f, -0.3f, 0.3f, 1.2f, 1.8f, 1}, new int[] {1, 6});
  34106. - List<String> categories = LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1);
  34107. - assertThat(categories.toArray())
  34108. - .isEqualTo(new String[] {"background", "apple", "apple", "banana", "banana", "banana"});
  34109. - }
  34110. + @Test
  34111. + public void mapIndexToStringsWithInvalidValues() {
  34112. + String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"};
  34113. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
  34114. + tensorBuffer.loadArray(new int[] {0, 1, 2, 3, 2, 5}, new int[] {1, 6});
  34115. + List<String> categories =
  34116. + LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1);
  34117. + assertThat(categories.toArray())
  34118. + .isEqualTo(new String[] {"apple", "banana", "cherry", "date", "cherry", ""});
  34119. + }
  34120. +
  34121. + @Test
  34122. + public void mapFloatIndexShouldCast() {
  34123. + String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"};
  34124. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  34125. + tensorBuffer.loadArray(new float[] {-1.1f, -0.3f, 0.3f, 1.2f, 1.8f, 1}, new int[] {1, 6});
  34126. + List<String> categories =
  34127. + LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1);
  34128. + assertThat(categories.toArray())
  34129. + .isEqualTo(new String[] {
  34130. + "background", "apple", "apple", "banana", "banana", "banana"});
  34131. + }
  34132. }
  34133. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java
  34134. index 4f296b7476c2d..857a77a2a4bd4 100644
  34135. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java
  34136. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java
  34137. @@ -17,10 +17,6 @@ package org.tensorflow.lite.support.label;
  34138. import static com.google.common.truth.Truth.assertThat;
  34139. -import java.util.Arrays;
  34140. -import java.util.HashMap;
  34141. -import java.util.List;
  34142. -import java.util.Map;
  34143. import org.junit.Assert;
  34144. import org.junit.Test;
  34145. import org.junit.runner.RunWith;
  34146. @@ -28,169 +24,180 @@ import org.robolectric.RobolectricTestRunner;
  34147. import org.tensorflow.lite.DataType;
  34148. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  34149. +import java.util.Arrays;
  34150. +import java.util.HashMap;
  34151. +import java.util.List;
  34152. +import java.util.Map;
  34153. +
  34154. /** Tests of {@link org.tensorflow.lite.support.label.TensorLabel}. */
  34155. @RunWith(RobolectricTestRunner.class)
  34156. public final class TensorLabelTest {
  34157. - @Test
  34158. - public void createTensorLabelWithNullAxisLabelsShouldFail() {
  34159. - int[] shape = {2};
  34160. - int[] arr = {1, 2};
  34161. - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  34162. - buffer.loadArray(arr, shape);
  34163. - Map<Integer, List<String>> nullAxisLabels = null;
  34164. -
  34165. - Assert.assertThrows(NullPointerException.class, () -> new TensorLabel(nullAxisLabels, buffer));
  34166. - }
  34167. -
  34168. - @Test
  34169. - public void createTensorLabelWithNullTensorBufferShouldFail() {
  34170. - Map<Integer, List<String>> axisLabels = new HashMap<>();
  34171. - axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
  34172. - TensorBuffer nullTensorBuffer = null;
  34173. -
  34174. - Assert.assertThrows(
  34175. - NullPointerException.class, () -> new TensorLabel(axisLabels, nullTensorBuffer));
  34176. - }
  34177. -
  34178. - @Test
  34179. - public void createTensorLabelWithStringListShouldSuccess() {
  34180. - TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 4, 3}, DataType.FLOAT32);
  34181. -
  34182. - TensorLabel tensorLabel = new TensorLabel(Arrays.asList("a", "b", "c", "d"), buffer);
  34183. -
  34184. - assertThat(tensorLabel.getMapWithTensorBuffer()).isNotNull();
  34185. - assertThat(tensorLabel.getMapWithTensorBuffer().keySet()).contains("c"); // randomly pick one
  34186. - }
  34187. -
  34188. - @Test
  34189. - public void createTensorLabelWithEmptyShapeShouldFail() {
  34190. - int[] shape = new int[] {};
  34191. - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34192. - Map<Integer, List<String>> axisLabels = new HashMap<>();
  34193. - axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
  34194. -
  34195. - Assert.assertThrows(IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
  34196. - }
  34197. -
  34198. - @Test
  34199. - public void createTensorLabelWithMismatchedAxisShouldFail() {
  34200. - int[] shape = {1, 4};
  34201. - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34202. - Map<Integer, List<String>> axisLabels = new HashMap<>();
  34203. - axisLabels.put(0, Arrays.asList("a", "b", "c", "d"));
  34204. -
  34205. - Assert.assertThrows(IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
  34206. - }
  34207. -
  34208. - @Test
  34209. - public void createTensorLabelWithMismatchedShapeShouldFail() {
  34210. - int[] shape = {1, 3};
  34211. - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34212. - Map<Integer, List<String>> axisLabels = new HashMap<>();
  34213. - axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
  34214. -
  34215. - Assert.assertThrows(IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
  34216. - }
  34217. -
  34218. - @Test
  34219. - public void getMapWithFloatBufferValuesShouldSuccess() {
  34220. - int numberLabel = 4;
  34221. - float[] inputArr = {0.5f, 0.2f, 0.2f, 0.1f};
  34222. - int[] shape = {1, numberLabel};
  34223. - TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34224. - input.loadArray(inputArr, shape);
  34225. - Map<Integer, List<String>> axisLabels = new HashMap<>();
  34226. - int labelAxis = 1;
  34227. - axisLabels.put(labelAxis, Arrays.asList("a", "b", "c", "d"));
  34228. -
  34229. - TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
  34230. - Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
  34231. -
  34232. - for (int i = 0; i < numberLabel; i++) {
  34233. - String label = axisLabels.get(labelAxis).get(i);
  34234. - assertThat(map).containsKey(label);
  34235. - float[] array = map.get(label).getFloatArray();
  34236. - assertThat(array).hasLength(1);
  34237. - assertThat(array[0]).isEqualTo(inputArr[i]);
  34238. + @Test
  34239. + public void createTensorLabelWithNullAxisLabelsShouldFail() {
  34240. + int[] shape = {2};
  34241. + int[] arr = {1, 2};
  34242. + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  34243. + buffer.loadArray(arr, shape);
  34244. + Map<Integer, List<String>> nullAxisLabels = null;
  34245. +
  34246. + Assert.assertThrows(
  34247. + NullPointerException.class, () -> new TensorLabel(nullAxisLabels, buffer));
  34248. + }
  34249. +
  34250. + @Test
  34251. + public void createTensorLabelWithNullTensorBufferShouldFail() {
  34252. + Map<Integer, List<String>> axisLabels = new HashMap<>();
  34253. + axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
  34254. + TensorBuffer nullTensorBuffer = null;
  34255. +
  34256. + Assert.assertThrows(
  34257. + NullPointerException.class, () -> new TensorLabel(axisLabels, nullTensorBuffer));
  34258. + }
  34259. +
  34260. + @Test
  34261. + public void createTensorLabelWithStringListShouldSuccess() {
  34262. + TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 4, 3}, DataType.FLOAT32);
  34263. +
  34264. + TensorLabel tensorLabel = new TensorLabel(Arrays.asList("a", "b", "c", "d"), buffer);
  34265. +
  34266. + assertThat(tensorLabel.getMapWithTensorBuffer()).isNotNull();
  34267. + assertThat(tensorLabel.getMapWithTensorBuffer().keySet())
  34268. + .contains("c"); // randomly pick one
  34269. + }
  34270. +
  34271. + @Test
  34272. + public void createTensorLabelWithEmptyShapeShouldFail() {
  34273. + int[] shape = new int[] {};
  34274. + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34275. + Map<Integer, List<String>> axisLabels = new HashMap<>();
  34276. + axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
  34277. +
  34278. + Assert.assertThrows(
  34279. + IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
  34280. }
  34281. - }
  34282. -
  34283. - @Test
  34284. - public void getMapWithIntBufferValuesShouldSuccess() {
  34285. - int numberLabel = 3;
  34286. - int[] inputArr = {1, 2, 0};
  34287. - int[] shape = {1, 1, numberLabel};
  34288. - TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  34289. - input.loadArray(inputArr, shape);
  34290. - Map<Integer, List<String>> axisLabels = new HashMap<>();
  34291. - int labelAxis = 2;
  34292. - axisLabels.put(labelAxis, Arrays.asList("x", "y", "z"));
  34293. -
  34294. - TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
  34295. - Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
  34296. -
  34297. - for (int i = 0; i < numberLabel; i++) {
  34298. - String label = axisLabels.get(labelAxis).get(i);
  34299. - assertThat(map).containsKey(label);
  34300. - int[] array = map.get(label).getIntArray();
  34301. - assertThat(array).hasLength(1);
  34302. - assertThat(array[0]).isEqualTo(inputArr[i]);
  34303. +
  34304. + @Test
  34305. + public void createTensorLabelWithMismatchedAxisShouldFail() {
  34306. + int[] shape = {1, 4};
  34307. + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34308. + Map<Integer, List<String>> axisLabels = new HashMap<>();
  34309. + axisLabels.put(0, Arrays.asList("a", "b", "c", "d"));
  34310. +
  34311. + Assert.assertThrows(
  34312. + IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
  34313. }
  34314. - }
  34315. -
  34316. - @Test
  34317. - public void getFloatMapShouldSuccess() {
  34318. - int[] shape = {1, 3};
  34319. - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34320. - buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f});
  34321. -
  34322. - TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer);
  34323. - Map<String, Float> map = tensorLabeled.getMapWithFloatValue();
  34324. -
  34325. - assertThat(map).hasSize(3);
  34326. - assertThat(map).containsEntry("a", 1.0f);
  34327. - assertThat(map).containsEntry("b", 2.0f);
  34328. - assertThat(map).containsEntry("c", 3.0f);
  34329. - }
  34330. -
  34331. - @Test
  34332. - public void getMapFromMultiDimensionalTensorBufferShouldSuccess() {
  34333. - int numberLabel = 2;
  34334. - int numDim = 3;
  34335. - float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f};
  34336. - int[] shape = {numberLabel, numDim};
  34337. - TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34338. - input.loadArray(inputArr, shape);
  34339. - Map<Integer, List<String>> axisLabels = new HashMap<>();
  34340. - int labelAxis = 0;
  34341. - axisLabels.put(labelAxis, Arrays.asList("pos", "neg"));
  34342. -
  34343. - TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
  34344. - Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
  34345. -
  34346. - for (int i = 0; i < numberLabel; i++) {
  34347. - String label = axisLabels.get(labelAxis).get(i);
  34348. - assertThat(map).containsKey(label);
  34349. -
  34350. - float[] array = map.get(label).getFloatArray();
  34351. - assertThat(array).hasLength(numDim);
  34352. - for (int j = 0; j < numDim; j++) {
  34353. - assertThat(array[j]).isEqualTo(inputArr[i * numDim + j]);
  34354. - }
  34355. +
  34356. + @Test
  34357. + public void createTensorLabelWithMismatchedShapeShouldFail() {
  34358. + int[] shape = {1, 3};
  34359. + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34360. + Map<Integer, List<String>> axisLabels = new HashMap<>();
  34361. + axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
  34362. +
  34363. + Assert.assertThrows(
  34364. + IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
  34365. + }
  34366. +
  34367. + @Test
  34368. + public void getMapWithFloatBufferValuesShouldSuccess() {
  34369. + int numberLabel = 4;
  34370. + float[] inputArr = {0.5f, 0.2f, 0.2f, 0.1f};
  34371. + int[] shape = {1, numberLabel};
  34372. + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34373. + input.loadArray(inputArr, shape);
  34374. + Map<Integer, List<String>> axisLabels = new HashMap<>();
  34375. + int labelAxis = 1;
  34376. + axisLabels.put(labelAxis, Arrays.asList("a", "b", "c", "d"));
  34377. +
  34378. + TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
  34379. + Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
  34380. +
  34381. + for (int i = 0; i < numberLabel; i++) {
  34382. + String label = axisLabels.get(labelAxis).get(i);
  34383. + assertThat(map).containsKey(label);
  34384. + float[] array = map.get(label).getFloatArray();
  34385. + assertThat(array).hasLength(1);
  34386. + assertThat(array[0]).isEqualTo(inputArr[i]);
  34387. + }
  34388. }
  34389. - }
  34390. - @Test
  34391. - public void getCategoryListShouldSuccess() {
  34392. - int[] shape = {1, 3};
  34393. - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34394. - buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f});
  34395. + @Test
  34396. + public void getMapWithIntBufferValuesShouldSuccess() {
  34397. + int numberLabel = 3;
  34398. + int[] inputArr = {1, 2, 0};
  34399. + int[] shape = {1, 1, numberLabel};
  34400. + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  34401. + input.loadArray(inputArr, shape);
  34402. + Map<Integer, List<String>> axisLabels = new HashMap<>();
  34403. + int labelAxis = 2;
  34404. + axisLabels.put(labelAxis, Arrays.asList("x", "y", "z"));
  34405. +
  34406. + TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
  34407. + Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
  34408. +
  34409. + for (int i = 0; i < numberLabel; i++) {
  34410. + String label = axisLabels.get(labelAxis).get(i);
  34411. + assertThat(map).containsKey(label);
  34412. + int[] array = map.get(label).getIntArray();
  34413. + assertThat(array).hasLength(1);
  34414. + assertThat(array[0]).isEqualTo(inputArr[i]);
  34415. + }
  34416. + }
  34417. - TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer);
  34418. - List<Category> categories = tensorLabeled.getCategoryList();
  34419. + @Test
  34420. + public void getFloatMapShouldSuccess() {
  34421. + int[] shape = {1, 3};
  34422. + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34423. + buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f});
  34424. - assertThat(categories).hasSize(3);
  34425. - assertThat(categories)
  34426. - .containsExactly(new Category("a", 1.0f), new Category("b", 2.0f), new Category("c", 3.0f));
  34427. - }
  34428. + TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer);
  34429. + Map<String, Float> map = tensorLabeled.getMapWithFloatValue();
  34430. +
  34431. + assertThat(map).hasSize(3);
  34432. + assertThat(map).containsEntry("a", 1.0f);
  34433. + assertThat(map).containsEntry("b", 2.0f);
  34434. + assertThat(map).containsEntry("c", 3.0f);
  34435. + }
  34436. +
  34437. + @Test
  34438. + public void getMapFromMultiDimensionalTensorBufferShouldSuccess() {
  34439. + int numberLabel = 2;
  34440. + int numDim = 3;
  34441. + float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f};
  34442. + int[] shape = {numberLabel, numDim};
  34443. + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34444. + input.loadArray(inputArr, shape);
  34445. + Map<Integer, List<String>> axisLabels = new HashMap<>();
  34446. + int labelAxis = 0;
  34447. + axisLabels.put(labelAxis, Arrays.asList("pos", "neg"));
  34448. +
  34449. + TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
  34450. + Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
  34451. +
  34452. + for (int i = 0; i < numberLabel; i++) {
  34453. + String label = axisLabels.get(labelAxis).get(i);
  34454. + assertThat(map).containsKey(label);
  34455. +
  34456. + float[] array = map.get(label).getFloatArray();
  34457. + assertThat(array).hasLength(numDim);
  34458. + for (int j = 0; j < numDim; j++) {
  34459. + assertThat(array[j]).isEqualTo(inputArr[i * numDim + j]);
  34460. + }
  34461. + }
  34462. + }
  34463. +
  34464. + @Test
  34465. + public void getCategoryListShouldSuccess() {
  34466. + int[] shape = {1, 3};
  34467. + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34468. + buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f});
  34469. +
  34470. + TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer);
  34471. + List<Category> categories = tensorLabeled.getCategoryList();
  34472. +
  34473. + assertThat(categories).hasSize(3);
  34474. + assertThat(categories)
  34475. + .containsExactly(
  34476. + new Category("a", 1.0f), new Category("b", 2.0f), new Category("c", 3.0f));
  34477. + }
  34478. }
  34479. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java
  34480. index 8fa8860a09ef5..c1afe99f34f34 100644
  34481. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java
  34482. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java
  34483. @@ -18,11 +18,9 @@ package org.tensorflow.lite.support.label.ops;
  34484. import static com.google.common.truth.Truth.assertThat;
  34485. import android.content.Context;
  34486. +
  34487. import androidx.test.core.app.ApplicationProvider;
  34488. -import java.io.IOException;
  34489. -import java.util.Arrays;
  34490. -import java.util.List;
  34491. -import java.util.Map;
  34492. +
  34493. import org.junit.Test;
  34494. import org.junit.runner.RunWith;
  34495. import org.robolectric.RobolectricTestRunner;
  34496. @@ -31,90 +29,94 @@ import org.tensorflow.lite.support.common.FileUtil;
  34497. import org.tensorflow.lite.support.label.TensorLabel;
  34498. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
  34499. +import java.io.IOException;
  34500. +import java.util.Arrays;
  34501. +import java.util.List;
  34502. +import java.util.Map;
  34503. +
  34504. /** Tests of {@link org.tensorflow.lite.support.label.ops.LabelAxisOp}. */
  34505. @RunWith(RobolectricTestRunner.class)
  34506. public final class LabelAxisOpTest {
  34507. + private final Context context = ApplicationProvider.getApplicationContext();
  34508. + private static final String LABEL_PATH = "flower_labels.txt";
  34509. +
  34510. + @Test
  34511. + public void testAddAxisLabelByStringList() {
  34512. + int numberLabel = 2;
  34513. + float[] inputArr = {0.7f, 0.3f};
  34514. +
  34515. + int[] shape = {numberLabel};
  34516. + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34517. + input.loadArray(inputArr, shape);
  34518. +
  34519. + List<String> labels = Arrays.asList("pos", "neg");
  34520. + LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(0, labels).build();
  34521. + TensorLabel output = op.apply(input);
  34522. + Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
  34523. +
  34524. + assertThat(map).containsKey("pos");
  34525. + float[] array = map.get("pos").getFloatArray();
  34526. + assertThat(array).hasLength(1);
  34527. + assertThat(array[0]).isEqualTo(0.7f);
  34528. +
  34529. + assertThat(map).containsKey("neg");
  34530. + array = map.get("neg").getFloatArray();
  34531. + assertThat(array).hasLength(1);
  34532. + assertThat(array[0]).isEqualTo(0.3f);
  34533. + }
  34534. +
  34535. + @Test
  34536. + public void testAddAxisLabelWithMultiDimensionTensor() throws IOException {
  34537. + int numberLabel = 2;
  34538. + int numDim = 3;
  34539. + float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f};
  34540. +
  34541. + int[] shape = {1, numberLabel, numDim};
  34542. + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34543. + input.loadArray(inputArr, shape);
  34544. - private final Context context = ApplicationProvider.getApplicationContext();
  34545. - private static final String LABEL_PATH = "flower_labels.txt";
  34546. -
  34547. - @Test
  34548. - public void testAddAxisLabelByStringList() {
  34549. - int numberLabel = 2;
  34550. - float[] inputArr = {0.7f, 0.3f};
  34551. -
  34552. - int[] shape = {numberLabel};
  34553. - TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34554. - input.loadArray(inputArr, shape);
  34555. -
  34556. - List<String> labels = Arrays.asList("pos", "neg");
  34557. - LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(0, labels).build();
  34558. - TensorLabel output = op.apply(input);
  34559. - Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
  34560. -
  34561. - assertThat(map).containsKey("pos");
  34562. - float[] array = map.get("pos").getFloatArray();
  34563. - assertThat(array).hasLength(1);
  34564. - assertThat(array[0]).isEqualTo(0.7f);
  34565. -
  34566. - assertThat(map).containsKey("neg");
  34567. - array = map.get("neg").getFloatArray();
  34568. - assertThat(array).hasLength(1);
  34569. - assertThat(array[0]).isEqualTo(0.3f);
  34570. - }
  34571. -
  34572. - @Test
  34573. - public void testAddAxisLabelWithMultiDimensionTensor() throws IOException {
  34574. - int numberLabel = 2;
  34575. - int numDim = 3;
  34576. - float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f};
  34577. -
  34578. - int[] shape = {1, numberLabel, numDim};
  34579. - TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  34580. - input.loadArray(inputArr, shape);
  34581. -
  34582. - List<String> labels = Arrays.asList("pos", "neg");
  34583. - LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(1, labels).build();
  34584. -
  34585. - TensorLabel output = op.apply(input);
  34586. - Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
  34587. -
  34588. - assertThat(map).containsKey("pos");
  34589. - float[] array = map.get("pos").getFloatArray();
  34590. - assertThat(array).hasLength(numDim);
  34591. - assertThat(array).isEqualTo(new float[] {0.5f, 0.1f, 0.3f});
  34592. -
  34593. - assertThat(map).containsKey("neg");
  34594. - array = map.get("neg").getFloatArray();
  34595. - assertThat(array).hasLength(numDim);
  34596. - assertThat(array).isEqualTo(new float[] {0.2f, 0.2f, 0.1f});
  34597. - }
  34598. -
  34599. - @Test
  34600. - public void testAddAxisLabelByFilePath() throws IOException {
  34601. - int numberLabel = 5;
  34602. - int[] inputArr = new int[numberLabel];
  34603. - for (int i = 0; i < numberLabel; i++) {
  34604. - inputArr[i] = i;
  34605. + List<String> labels = Arrays.asList("pos", "neg");
  34606. + LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(1, labels).build();
  34607. +
  34608. + TensorLabel output = op.apply(input);
  34609. + Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
  34610. +
  34611. + assertThat(map).containsKey("pos");
  34612. + float[] array = map.get("pos").getFloatArray();
  34613. + assertThat(array).hasLength(numDim);
  34614. + assertThat(array).isEqualTo(new float[] {0.5f, 0.1f, 0.3f});
  34615. +
  34616. + assertThat(map).containsKey("neg");
  34617. + array = map.get("neg").getFloatArray();
  34618. + assertThat(array).hasLength(numDim);
  34619. + assertThat(array).isEqualTo(new float[] {0.2f, 0.2f, 0.1f});
  34620. }
  34621. - int[] shape = {numberLabel};
  34622. - TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  34623. - input.loadArray(inputArr, shape);
  34624. + @Test
  34625. + public void testAddAxisLabelByFilePath() throws IOException {
  34626. + int numberLabel = 5;
  34627. + int[] inputArr = new int[numberLabel];
  34628. + for (int i = 0; i < numberLabel; i++) {
  34629. + inputArr[i] = i;
  34630. + }
  34631. +
  34632. + int[] shape = {numberLabel};
  34633. + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  34634. + input.loadArray(inputArr, shape);
  34635. - LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(context, 0, LABEL_PATH).build();
  34636. - TensorLabel output = op.apply(input);
  34637. - Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
  34638. + LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(context, 0, LABEL_PATH).build();
  34639. + TensorLabel output = op.apply(input);
  34640. + Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
  34641. - List<String> labels = FileUtil.loadLabels(context, LABEL_PATH);
  34642. - for (int i = 0; i < numberLabel; i++) {
  34643. - String label = labels.get(i);
  34644. + List<String> labels = FileUtil.loadLabels(context, LABEL_PATH);
  34645. + for (int i = 0; i < numberLabel; i++) {
  34646. + String label = labels.get(i);
  34647. - assertThat(map).containsKey(label);
  34648. + assertThat(map).containsKey(label);
  34649. - int[] array = map.get(label).getIntArray();
  34650. - assertThat(array).hasLength(1);
  34651. - assertThat(array[0]).isEqualTo(inputArr[i]);
  34652. + int[] array = map.get(label).getIntArray();
  34653. + assertThat(array).hasLength(1);
  34654. + assertThat(array[0]).isEqualTo(inputArr[i]);
  34655. + }
  34656. }
  34657. - }
  34658. }
  34659. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java
  34660. index bd59051ce4ccb..d7449187cb54c 100644
  34661. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java
  34662. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java
  34663. @@ -17,6 +17,7 @@ package org.tensorflow.lite.support.model;
  34664. import static com.google.common.truth.Truth.assertThat;
  34665. import androidx.test.ext.junit.runners.AndroidJUnit4;
  34666. +
  34667. import org.junit.Test;
  34668. import org.junit.runner.RunWith;
  34669. @@ -27,13 +28,12 @@ import org.junit.runner.RunWith;
  34670. */
  34671. @RunWith(AndroidJUnit4.class)
  34672. public final class GpuDelegateProxyInstrumentedTest {
  34673. -
  34674. - @Test
  34675. - public void createGpuDelegateProxyShouldSuccess() {
  34676. - GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance();
  34677. -
  34678. - assertThat(proxy).isNotNull();
  34679. - proxy.getNativeHandle();
  34680. - proxy.close();
  34681. - }
  34682. + @Test
  34683. + public void createGpuDelegateProxyShouldSuccess() {
  34684. + GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance();
  34685. +
  34686. + assertThat(proxy).isNotNull();
  34687. + proxy.getNativeHandle();
  34688. + proxy.close();
  34689. + }
  34690. }
  34691. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java
  34692. index c1bbcc223a895..4eb2e2920c3bc 100644
  34693. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java
  34694. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java
  34695. @@ -23,11 +23,10 @@ import org.robolectric.RobolectricTestRunner;
  34696. /** Tests of {@link org.tensorflow.lite.support.model.GpuDelegateProxy}. */
  34697. @RunWith(RobolectricTestRunner.class)
  34698. public final class GpuDelegateProxyTest {
  34699. + @Test
  34700. + public void createGpuDelegateProxyWithoutDependencyShouldReturnNull() {
  34701. + GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance();
  34702. - @Test
  34703. - public void createGpuDelegateProxyWithoutDependencyShouldReturnNull() {
  34704. - GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance();
  34705. -
  34706. - assertThat(proxy).isNull();
  34707. - }
  34708. + assertThat(proxy).isNull();
  34709. + }
  34710. }
  34711. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java
  34712. index 86e4f72769216..342e82b2de3bb 100644
  34713. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java
  34714. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java
  34715. @@ -16,143 +16,145 @@ limitations under the License.
  34716. package org.tensorflow.lite.support.model;
  34717. import static com.google.common.truth.Truth.assertThat;
  34718. +
  34719. import static org.junit.Assert.fail;
  34720. import android.content.Context;
  34721. +
  34722. import androidx.test.core.app.ApplicationProvider;
  34723. -import java.io.IOException;
  34724. -import java.nio.MappedByteBuffer;
  34725. -import java.util.HashMap;
  34726. -import java.util.Map;
  34727. +
  34728. +import org.junit.Ignore;
  34729. import org.junit.Test;
  34730. import org.junit.runner.RunWith;
  34731. import org.robolectric.RobolectricTestRunner;
  34732. import org.tensorflow.lite.support.model.Model.Device;
  34733. import org.tensorflow.lite.support.model.Model.Options;
  34734. -import org.junit.Ignore;
  34735. +import java.io.IOException;
  34736. +import java.nio.MappedByteBuffer;
  34737. +import java.util.HashMap;
  34738. +import java.util.Map;
  34739. /** Tests of {@link org.tensorflow.lite.support.model.Model}. */
  34740. @RunWith(RobolectricTestRunner.class)
  34741. public final class ModelTest {
  34742. + private final Context context = ApplicationProvider.getApplicationContext();
  34743. + private static final String MODEL_PATH = "add.tflite";
  34744. +
  34745. + @Ignore
  34746. + @Test
  34747. + public void testLoadLocalModel() throws IOException {
  34748. + MappedByteBuffer byteModel = new Model.Builder(context, MODEL_PATH).build().getData();
  34749. + assertThat(byteModel).isNotNull();
  34750. + }
  34751. +
  34752. + @Ignore
  34753. + @Test
  34754. + public void testBuildMultiThreadModel() throws IOException {
  34755. + MappedByteBuffer byteModel =
  34756. + new Model.Builder(context, MODEL_PATH).setNumThreads(4).build().getData();
  34757. + assertThat(byteModel).isNotNull();
  34758. + }
  34759. +
  34760. + @Ignore
  34761. + @Test
  34762. + public void buildModelWithOptionsShouldSuccess() throws IOException {
  34763. + Options options = new Options.Builder().setNumThreads(2).setDevice(Device.NNAPI).build();
  34764. + Model model = Model.createModel(context, MODEL_PATH, options);
  34765. + assertThat(model.getData()).isNotNull();
  34766. + }
  34767. - private final Context context = ApplicationProvider.getApplicationContext();
  34768. - private static final String MODEL_PATH = "add.tflite";
  34769. -
  34770. - @Ignore
  34771. - @Test
  34772. - public void testLoadLocalModel() throws IOException {
  34773. - MappedByteBuffer byteModel = new Model.Builder(context, MODEL_PATH).build().getData();
  34774. - assertThat(byteModel).isNotNull();
  34775. - }
  34776. -
  34777. - @Ignore
  34778. - @Test
  34779. - public void testBuildMultiThreadModel() throws IOException {
  34780. - MappedByteBuffer byteModel =
  34781. - new Model.Builder(context, MODEL_PATH).setNumThreads(4).build().getData();
  34782. - assertThat(byteModel).isNotNull();
  34783. - }
  34784. -
  34785. - @Ignore
  34786. - @Test
  34787. - public void buildModelWithOptionsShouldSuccess() throws IOException {
  34788. - Options options = new Options.Builder().setNumThreads(2).setDevice(Device.NNAPI).build();
  34789. - Model model = Model.createModel(context, MODEL_PATH, options);
  34790. - assertThat(model.getData()).isNotNull();
  34791. - }
  34792. -
  34793. - @Ignore
  34794. - @Test
  34795. - public void testGetModelPath() throws IOException {
  34796. - String modelPath = new Model.Builder(context, MODEL_PATH).build().getPath();
  34797. - assertThat(modelPath).isEqualTo(MODEL_PATH);
  34798. - }
  34799. -
  34800. - @Test
  34801. - public void testNonExistingLocalModel() {
  34802. - try {
  34803. - new Model.Builder(context, "non_exist_model_file").build();
  34804. - fail();
  34805. - } catch (IOException e) {
  34806. - assertThat(e).hasMessageThat().contains("non_exist_model_file");
  34807. + @Ignore
  34808. + @Test
  34809. + public void testGetModelPath() throws IOException {
  34810. + String modelPath = new Model.Builder(context, MODEL_PATH).build().getPath();
  34811. + assertThat(modelPath).isEqualTo(MODEL_PATH);
  34812. }
  34813. - }
  34814. -
  34815. - @Test
  34816. - public void testNullLocalModelPath() throws IOException {
  34817. - try {
  34818. - new Model.Builder(context, null).build();
  34819. - fail();
  34820. - } catch (NullPointerException e) {
  34821. - assertThat(e).hasMessageThat().contains("File path cannot be null.");
  34822. +
  34823. + @Test
  34824. + public void testNonExistingLocalModel() {
  34825. + try {
  34826. + new Model.Builder(context, "non_exist_model_file").build();
  34827. + fail();
  34828. + } catch (IOException e) {
  34829. + assertThat(e).hasMessageThat().contains("non_exist_model_file");
  34830. + }
  34831. }
  34832. - }
  34833. -
  34834. - @Test
  34835. - public void testNullContext() throws IOException {
  34836. - try {
  34837. - new Model.Builder(null, MODEL_PATH).build();
  34838. - fail();
  34839. - } catch (NullPointerException e) {
  34840. - assertThat(e).hasMessageThat().contains("Context should not be null.");
  34841. +
  34842. + @Test
  34843. + public void testNullLocalModelPath() throws IOException {
  34844. + try {
  34845. + new Model.Builder(context, null).build();
  34846. + fail();
  34847. + } catch (NullPointerException e) {
  34848. + assertThat(e).hasMessageThat().contains("File path cannot be null.");
  34849. + }
  34850. + }
  34851. +
  34852. + @Test
  34853. + public void testNullContext() throws IOException {
  34854. + try {
  34855. + new Model.Builder(null, MODEL_PATH).build();
  34856. + fail();
  34857. + } catch (NullPointerException e) {
  34858. + assertThat(e).hasMessageThat().contains("Context should not be null.");
  34859. + }
  34860. + }
  34861. +
  34862. + @Ignore
  34863. + @Test
  34864. + public void testGetInputTensor() throws IOException {
  34865. + Options options = new Options.Builder().build();
  34866. + Model model = Model.createModel(context, MODEL_PATH, options);
  34867. + assertThat(model.getInputTensor(0)).isNotNull();
  34868. + }
  34869. +
  34870. + @Ignore
  34871. + @Test
  34872. + public void testGetOutputTensor() throws IOException {
  34873. + Options options = new Options.Builder().build();
  34874. + Model model = Model.createModel(context, MODEL_PATH, options);
  34875. + assertThat(model.getOutputTensor(0)).isNotNull();
  34876. + }
  34877. +
  34878. + @Ignore
  34879. + @Test
  34880. + public void testRun() throws IOException {
  34881. + Context context = ApplicationProvider.getApplicationContext();
  34882. + Model model = new Model.Builder(context, MODEL_PATH).build();
  34883. + runModel(model);
  34884. + }
  34885. +
  34886. + @Ignore
  34887. + @Test
  34888. + public void testMultiThreadingRun() throws IOException {
  34889. + Context context = ApplicationProvider.getApplicationContext();
  34890. + Model model = new Model.Builder(context, MODEL_PATH).setNumThreads(4).build();
  34891. + runModel(model);
  34892. + }
  34893. +
  34894. + @Ignore
  34895. + @Test
  34896. + public void testNnApiRun() throws IOException {
  34897. + Context context = ApplicationProvider.getApplicationContext();
  34898. + Model model = new Model.Builder(context, MODEL_PATH).setDevice(Device.NNAPI).build();
  34899. + runModel(model);
  34900. + }
  34901. +
  34902. + private static void runModel(Model model) throws IOException {
  34903. + // Creates the inputs.
  34904. + float[] x = {1.5f};
  34905. + float[] y = {0.5f};
  34906. + float[] expectedSum = {2.0f};
  34907. + Object[] inputs = {x, y};
  34908. +
  34909. + // Creates the outputs buffer.
  34910. + float[] sum = new float[1];
  34911. + Map<Integer, Object> outputs = new HashMap<>();
  34912. + outputs.put(0, sum);
  34913. +
  34914. + // Runs inference.
  34915. + model.run(inputs, outputs);
  34916. + assertThat(sum).isEqualTo(expectedSum);
  34917. }
  34918. - }
  34919. -
  34920. - @Ignore
  34921. - @Test
  34922. - public void testGetInputTensor() throws IOException {
  34923. - Options options = new Options.Builder().build();
  34924. - Model model = Model.createModel(context, MODEL_PATH, options);
  34925. - assertThat(model.getInputTensor(0)).isNotNull();
  34926. - }
  34927. -
  34928. - @Ignore
  34929. - @Test
  34930. - public void testGetOutputTensor() throws IOException {
  34931. - Options options = new Options.Builder().build();
  34932. - Model model = Model.createModel(context, MODEL_PATH, options);
  34933. - assertThat(model.getOutputTensor(0)).isNotNull();
  34934. - }
  34935. -
  34936. - @Ignore
  34937. - @Test
  34938. - public void testRun() throws IOException {
  34939. - Context context = ApplicationProvider.getApplicationContext();
  34940. - Model model = new Model.Builder(context, MODEL_PATH).build();
  34941. - runModel(model);
  34942. - }
  34943. -
  34944. - @Ignore
  34945. - @Test
  34946. - public void testMultiThreadingRun() throws IOException {
  34947. - Context context = ApplicationProvider.getApplicationContext();
  34948. - Model model = new Model.Builder(context, MODEL_PATH).setNumThreads(4).build();
  34949. - runModel(model);
  34950. - }
  34951. -
  34952. - @Ignore
  34953. - @Test
  34954. - public void testNnApiRun() throws IOException {
  34955. - Context context = ApplicationProvider.getApplicationContext();
  34956. - Model model = new Model.Builder(context, MODEL_PATH).setDevice(Device.NNAPI).build();
  34957. - runModel(model);
  34958. - }
  34959. -
  34960. - private static void runModel(Model model) throws IOException {
  34961. - // Creates the inputs.
  34962. - float[] x = {1.5f};
  34963. - float[] y = {0.5f};
  34964. - float[] expectedSum = {2.0f};
  34965. - Object[] inputs = {x, y};
  34966. -
  34967. - // Creates the outputs buffer.
  34968. - float[] sum = new float[1];
  34969. - Map<Integer, Object> outputs = new HashMap<>();
  34970. - outputs.put(0, sum);
  34971. -
  34972. - // Runs inference.
  34973. - model.run(inputs, outputs);
  34974. - assertThat(sum).isEqualTo(expectedSum);
  34975. - }
  34976. }
  34977. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java
  34978. index 3a4d09d8e5701..82b59b36155f3 100644
  34979. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java
  34980. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java
  34981. @@ -26,51 +26,51 @@ import org.tensorflow.lite.DataType;
  34982. /** Tests of {@link org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat}. */
  34983. @RunWith(RobolectricTestRunner.class)
  34984. public final class TensorBufferFloatTest {
  34985. - @Test
  34986. - public void testCreateDynamic() {
  34987. - TensorBufferFloat tensorBufferFloat = new TensorBufferFloat();
  34988. - assertThat(tensorBufferFloat).isNotNull();
  34989. - }
  34990. + @Test
  34991. + public void testCreateDynamic() {
  34992. + TensorBufferFloat tensorBufferFloat = new TensorBufferFloat();
  34993. + assertThat(tensorBufferFloat).isNotNull();
  34994. + }
  34995. - @Test
  34996. - public void testCreateFixedSize() {
  34997. - int[] shape = new int[] {1, 2, 3};
  34998. - TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
  34999. - assertThat(tensorBufferFloat).isNotNull();
  35000. - assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6);
  35001. - }
  35002. + @Test
  35003. + public void testCreateFixedSize() {
  35004. + int[] shape = new int[] {1, 2, 3};
  35005. + TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
  35006. + assertThat(tensorBufferFloat).isNotNull();
  35007. + assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6);
  35008. + }
  35009. - @Test
  35010. - public void testCreateFixedSizeWithScalarShape() {
  35011. - int[] shape = new int[] {};
  35012. - TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
  35013. - assertThat(tensorBufferFloat).isNotNull();
  35014. - assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(1);
  35015. - }
  35016. + @Test
  35017. + public void testCreateFixedSizeWithScalarShape() {
  35018. + int[] shape = new int[] {};
  35019. + TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
  35020. + assertThat(tensorBufferFloat).isNotNull();
  35021. + assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(1);
  35022. + }
  35023. - @Test
  35024. - public void testCreateWithNullShape() {
  35025. - int[] shape = null;
  35026. - Assert.assertThrows(NullPointerException.class, () -> new TensorBufferFloat(shape));
  35027. - }
  35028. + @Test
  35029. + public void testCreateWithNullShape() {
  35030. + int[] shape = null;
  35031. + Assert.assertThrows(NullPointerException.class, () -> new TensorBufferFloat(shape));
  35032. + }
  35033. - @Test
  35034. - public void testCreateWithInvalidShape() {
  35035. - int[] shape = new int[] {1, -1, 2};
  35036. - Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferFloat(shape));
  35037. - }
  35038. + @Test
  35039. + public void testCreateWithInvalidShape() {
  35040. + int[] shape = new int[] {1, -1, 2};
  35041. + Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferFloat(shape));
  35042. + }
  35043. - @Test
  35044. - public void testCreateUsingShapeWithZero() {
  35045. - int[] shape = new int[] {1, 0, 2};
  35046. - TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
  35047. - assertThat(tensorBufferFloat).isNotNull();
  35048. - assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(0);
  35049. - }
  35050. + @Test
  35051. + public void testCreateUsingShapeWithZero() {
  35052. + int[] shape = new int[] {1, 0, 2};
  35053. + TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
  35054. + assertThat(tensorBufferFloat).isNotNull();
  35055. + assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(0);
  35056. + }
  35057. - @Test
  35058. - public void testGetDataType() {
  35059. - TensorBufferFloat tensorBufferFloat = new TensorBufferFloat();
  35060. - assertThat(tensorBufferFloat.getDataType()).isEqualTo(DataType.FLOAT32);
  35061. - }
  35062. + @Test
  35063. + public void testGetDataType() {
  35064. + TensorBufferFloat tensorBufferFloat = new TensorBufferFloat();
  35065. + assertThat(tensorBufferFloat.getDataType()).isEqualTo(DataType.FLOAT32);
  35066. + }
  35067. }
  35068. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java
  35069. index c55affe733eac..763356f493390 100644
  35070. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java
  35071. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java
  35072. @@ -16,877 +16,878 @@ limitations under the License.
  35073. package org.tensorflow.lite.support.tensorbuffer;
  35074. import static com.google.common.truth.Truth.assertThat;
  35075. +
  35076. import static org.junit.Assert.assertThrows;
  35077. -import java.io.IOException;
  35078. -import java.nio.ByteBuffer;
  35079. -import java.nio.FloatBuffer;
  35080. -import java.util.ArrayList;
  35081. import org.junit.Assert;
  35082. import org.junit.Test;
  35083. import org.junit.runner.RunWith;
  35084. import org.robolectric.RobolectricTestRunner;
  35085. import org.tensorflow.lite.DataType;
  35086. +import java.io.IOException;
  35087. +import java.nio.ByteBuffer;
  35088. +import java.nio.FloatBuffer;
  35089. +import java.util.ArrayList;
  35090. +
  35091. /** Test helper class for inserting and retrieving arrays. */
  35092. class ArrayTestRunner {
  35093. - // List of TensorBuffer types to be tested.
  35094. - private static final DataType[] BUFFER_TYPE_LIST = {DataType.FLOAT32, DataType.UINT8};
  35095. - // List of source arrays to be loaded into TensorBuffer during the tests.
  35096. - private final ArrayList<Object> srcArrays;
  35097. - // List of array data type with respect to srcArrays.
  35098. - private final ArrayList<DataType> arrDataTypes;
  35099. - // List of array shape with respect to srcArrays.
  35100. - private final ArrayList<int[]> arrShapes;
  35101. - private final int[] tensorBufferShape;
  35102. - private final ExpectedResults expectedResForFloatBuf;
  35103. - private final ExpectedResults expectedResForByteBuf;
  35104. -
  35105. - public ArrayTestRunner(Builder builder) {
  35106. - if (builder.srcArrays.size() != builder.arrDataTypes.size()) {
  35107. - throw new IllegalArgumentException(
  35108. - "Number of source arrays and number of data types do not match.");
  35109. - }
  35110. -
  35111. - this.srcArrays = builder.srcArrays;
  35112. - this.arrDataTypes = builder.arrDataTypes;
  35113. - this.arrShapes = builder.arrShapes;
  35114. - this.tensorBufferShape = builder.tensorBufferShape;
  35115. - this.expectedResForFloatBuf = builder.expectedResForFloatBuf;
  35116. - this.expectedResForByteBuf = builder.expectedResForByteBuf;
  35117. - }
  35118. -
  35119. - static class ExpectedResults {
  35120. - public float[] floatArr;
  35121. - public int[] intArr;
  35122. - public int[] shape;
  35123. - }
  35124. -
  35125. - public static class Builder {
  35126. - private final ArrayList<Object> srcArrays = new ArrayList<>();
  35127. - private final ArrayList<DataType> arrDataTypes = new ArrayList<>();
  35128. - private final ArrayList<int[]> arrShapes = new ArrayList<>();
  35129. - private int[] tensorBufferShape;
  35130. - private final ExpectedResults expectedResForFloatBuf = new ExpectedResults();
  35131. - private final ExpectedResults expectedResForByteBuf = new ExpectedResults();
  35132. -
  35133. - public static Builder newInstance() {
  35134. - return new Builder();
  35135. - }
  35136. -
  35137. - private Builder() {}
  35138. -
  35139. - /** Loads a test array into the test runner. */
  35140. - public Builder addSrcArray(Object src, int[] shape) {
  35141. - // src should be a primitive 1D array.
  35142. - DataType dataType = dataTypeOfArray(src);
  35143. - switch (dataType) {
  35144. - case INT32:
  35145. - case FLOAT32:
  35146. - srcArrays.add(src);
  35147. - arrDataTypes.add(dataType);
  35148. - arrShapes.add(shape);
  35149. - return this;
  35150. - default:
  35151. - throw new AssertionError("Cannot resolve srouce arrays in the DataType of " + dataType);
  35152. - }
  35153. - }
  35154. -
  35155. - public Builder setTensorBufferShape(int[] tensorBufferShape) {
  35156. - this.tensorBufferShape = tensorBufferShape;
  35157. - return this;
  35158. - }
  35159. -
  35160. - public Builder setExpectedResults(
  35161. - DataType bufferType, float[] expectedFloatArr, int[] expectedIntArr) {
  35162. - ExpectedResults er;
  35163. - switch (bufferType) {
  35164. - case UINT8:
  35165. - er = expectedResForByteBuf;
  35166. - break;
  35167. - case FLOAT32:
  35168. - er = expectedResForFloatBuf;
  35169. - break;
  35170. - default:
  35171. - throw new AssertionError("Cannot test TensorBuffer in the DataType of " + bufferType);
  35172. - }
  35173. -
  35174. - er.floatArr = expectedFloatArr;
  35175. - er.intArr = expectedIntArr;
  35176. - return this;
  35177. - }
  35178. -
  35179. - public ArrayTestRunner build() {
  35180. - int[] expectedShape;
  35181. - if (arrShapes.isEmpty()) {
  35182. - // If no array will be loaded, the array is an empty array.
  35183. - expectedShape = new int[] {0};
  35184. - } else {
  35185. - expectedShape = arrShapes.get(arrShapes.size() - 1);
  35186. - }
  35187. - expectedResForByteBuf.shape = expectedShape;
  35188. - expectedResForFloatBuf.shape = expectedShape;
  35189. - return new ArrayTestRunner(this);
  35190. - }
  35191. - }
  35192. -
  35193. - public static DataType[] getBufferTypeList() {
  35194. - return BUFFER_TYPE_LIST;
  35195. - }
  35196. -
  35197. - /**
  35198. - * Runs tests in the following steps: 1. Create a TensorBuffer. If tensorBufferShape is null,
  35199. - * create a dynamic buffer. Otherwise, create a fixed-size buffer accordingly. 2. Load arrays in
  35200. - * srcArrays one by one into the TensotBuffer. 3. Get arrays for each supported primitive types in
  35201. - * TensorBuffer, such as int array and float array for now. Check if the results are correct. 4.
  35202. - * Repeat Step 1 to 3 for all buffer types in BUFFER_TYPE_LIST.
  35203. - */
  35204. - public void run() {
  35205. - for (DataType bufferDataType : BUFFER_TYPE_LIST) {
  35206. - TensorBuffer tensorBuffer;
  35207. - if (tensorBufferShape == null) {
  35208. - tensorBuffer = TensorBuffer.createDynamic(bufferDataType);
  35209. - } else {
  35210. - tensorBuffer = TensorBuffer.createFixedSize(tensorBufferShape, bufferDataType);
  35211. - }
  35212. - for (int i = 0; i < srcArrays.size(); i++) {
  35213. - switch (arrDataTypes.get(i)) {
  35214. - case INT32:
  35215. - int[] arrInt = (int[]) srcArrays.get(i);
  35216. - tensorBuffer.loadArray(arrInt, arrShapes.get(i));
  35217. - break;
  35218. - case FLOAT32:
  35219. - float[] arrFloat = (float[]) srcArrays.get(i);
  35220. - tensorBuffer.loadArray(arrFloat, arrShapes.get(i));
  35221. - break;
  35222. - default:
  35223. - break;
  35224. + // List of TensorBuffer types to be tested.
  35225. + private static final DataType[] BUFFER_TYPE_LIST = {DataType.FLOAT32, DataType.UINT8};
  35226. + // List of source arrays to be loaded into TensorBuffer during the tests.
  35227. + private final ArrayList<Object> srcArrays;
  35228. + // List of array data type with respect to srcArrays.
  35229. + private final ArrayList<DataType> arrDataTypes;
  35230. + // List of array shape with respect to srcArrays.
  35231. + private final ArrayList<int[]> arrShapes;
  35232. + private final int[] tensorBufferShape;
  35233. + private final ExpectedResults expectedResForFloatBuf;
  35234. + private final ExpectedResults expectedResForByteBuf;
  35235. +
  35236. + public ArrayTestRunner(Builder builder) {
  35237. + if (builder.srcArrays.size() != builder.arrDataTypes.size()) {
  35238. + throw new IllegalArgumentException(
  35239. + "Number of source arrays and number of data types do not match.");
  35240. }
  35241. - }
  35242. - checkResults(tensorBuffer);
  35243. - }
  35244. - }
  35245. -
  35246. - private void checkResults(TensorBuffer tensorBuffer) {
  35247. - ExpectedResults er;
  35248. - switch (tensorBuffer.getDataType()) {
  35249. - case UINT8:
  35250. - er = expectedResForByteBuf;
  35251. - break;
  35252. - case FLOAT32:
  35253. - er = expectedResForFloatBuf;
  35254. - break;
  35255. - default:
  35256. - throw new AssertionError(
  35257. - "Cannot test TensorBuffer in the DataType of " + tensorBuffer.getDataType());
  35258. - }
  35259. -
  35260. - // Checks getIntArray() and getFloatArray().
  35261. - int[] resIntArr = tensorBuffer.getIntArray();
  35262. - assertThat(resIntArr).isEqualTo(er.intArr);
  35263. - float[] resFloatArr = tensorBuffer.getFloatArray();
  35264. - assertThat(resFloatArr).isEqualTo(er.floatArr);
  35265. - assertThat(tensorBuffer.getShape()).isEqualTo(er.shape);
  35266. -
  35267. - // Checks getIntValue(int index) and getFloatValue(int index).
  35268. - int flatSize = tensorBuffer.getFlatSize();
  35269. - float[] resFloatValues = new float[flatSize];
  35270. - int[] resIntValues = new int[flatSize];
  35271. - for (int i = 0; i < flatSize; i++) {
  35272. - resFloatValues[i] = tensorBuffer.getFloatValue(i);
  35273. - resIntValues[i] = tensorBuffer.getIntValue(i);
  35274. - }
  35275. - assertThat(resFloatValues).isEqualTo(er.floatArr);
  35276. - assertThat(resIntValues).isEqualTo(er.intArr);
  35277. - }
  35278. -
  35279. - /** Gets the data type of an 1D array. */
  35280. - private static DataType dataTypeOfArray(Object arr) {
  35281. - if (arr != null) {
  35282. - Class<?> c = arr.getClass();
  35283. - if (c.isArray()) {
  35284. - c = c.getComponentType();
  35285. - if (float.class.equals(c)) {
  35286. - return DataType.FLOAT32;
  35287. - } else if (int.class.equals(c)) {
  35288. - return DataType.INT32;
  35289. - } else if (byte.class.equals(c)) {
  35290. - return DataType.UINT8;
  35291. - } else if (long.class.equals(c)) {
  35292. - return DataType.INT64;
  35293. - } else if (String.class.equals(c)) {
  35294. - return DataType.STRING;
  35295. +
  35296. + this.srcArrays = builder.srcArrays;
  35297. + this.arrDataTypes = builder.arrDataTypes;
  35298. + this.arrShapes = builder.arrShapes;
  35299. + this.tensorBufferShape = builder.tensorBufferShape;
  35300. + this.expectedResForFloatBuf = builder.expectedResForFloatBuf;
  35301. + this.expectedResForByteBuf = builder.expectedResForByteBuf;
  35302. + }
  35303. +
  35304. + static class ExpectedResults {
  35305. + public float[] floatArr;
  35306. + public int[] intArr;
  35307. + public int[] shape;
  35308. + }
  35309. +
  35310. + public static class Builder {
  35311. + private final ArrayList<Object> srcArrays = new ArrayList<>();
  35312. + private final ArrayList<DataType> arrDataTypes = new ArrayList<>();
  35313. + private final ArrayList<int[]> arrShapes = new ArrayList<>();
  35314. + private int[] tensorBufferShape;
  35315. + private final ExpectedResults expectedResForFloatBuf = new ExpectedResults();
  35316. + private final ExpectedResults expectedResForByteBuf = new ExpectedResults();
  35317. +
  35318. + public static Builder newInstance() {
  35319. + return new Builder();
  35320. + }
  35321. +
  35322. + private Builder() {}
  35323. +
  35324. + /** Loads a test array into the test runner. */
  35325. + public Builder addSrcArray(Object src, int[] shape) {
  35326. + // src should be a primitive 1D array.
  35327. + DataType dataType = dataTypeOfArray(src);
  35328. + switch (dataType) {
  35329. + case INT32:
  35330. + case FLOAT32:
  35331. + srcArrays.add(src);
  35332. + arrDataTypes.add(dataType);
  35333. + arrShapes.add(shape);
  35334. + return this;
  35335. + default:
  35336. + throw new AssertionError(
  35337. + "Cannot resolve srouce arrays in the DataType of " + dataType);
  35338. + }
  35339. + }
  35340. +
  35341. + public Builder setTensorBufferShape(int[] tensorBufferShape) {
  35342. + this.tensorBufferShape = tensorBufferShape;
  35343. + return this;
  35344. }
  35345. - }
  35346. +
  35347. + public Builder setExpectedResults(
  35348. + DataType bufferType, float[] expectedFloatArr, int[] expectedIntArr) {
  35349. + ExpectedResults er;
  35350. + switch (bufferType) {
  35351. + case UINT8:
  35352. + er = expectedResForByteBuf;
  35353. + break;
  35354. + case FLOAT32:
  35355. + er = expectedResForFloatBuf;
  35356. + break;
  35357. + default:
  35358. + throw new AssertionError(
  35359. + "Cannot test TensorBuffer in the DataType of " + bufferType);
  35360. + }
  35361. +
  35362. + er.floatArr = expectedFloatArr;
  35363. + er.intArr = expectedIntArr;
  35364. + return this;
  35365. + }
  35366. +
  35367. + public ArrayTestRunner build() {
  35368. + int[] expectedShape;
  35369. + if (arrShapes.isEmpty()) {
  35370. + // If no array will be loaded, the array is an empty array.
  35371. + expectedShape = new int[] {0};
  35372. + } else {
  35373. + expectedShape = arrShapes.get(arrShapes.size() - 1);
  35374. + }
  35375. + expectedResForByteBuf.shape = expectedShape;
  35376. + expectedResForFloatBuf.shape = expectedShape;
  35377. + return new ArrayTestRunner(this);
  35378. + }
  35379. + }
  35380. +
  35381. + public static DataType[] getBufferTypeList() {
  35382. + return BUFFER_TYPE_LIST;
  35383. + }
  35384. +
  35385. + /**
  35386. + * Runs tests in the following steps: 1. Create a TensorBuffer. If tensorBufferShape is null,
  35387. + * create a dynamic buffer. Otherwise, create a fixed-size buffer accordingly. 2. Load arrays in
  35388. + * srcArrays one by one into the TensotBuffer. 3. Get arrays for each supported primitive types
  35389. + * in TensorBuffer, such as int array and float array for now. Check if the results are
  35390. + * correct. 4. Repeat Step 1 to 3 for all buffer types in BUFFER_TYPE_LIST.
  35391. + */
  35392. + public void run() {
  35393. + for (DataType bufferDataType : BUFFER_TYPE_LIST) {
  35394. + TensorBuffer tensorBuffer;
  35395. + if (tensorBufferShape == null) {
  35396. + tensorBuffer = TensorBuffer.createDynamic(bufferDataType);
  35397. + } else {
  35398. + tensorBuffer = TensorBuffer.createFixedSize(tensorBufferShape, bufferDataType);
  35399. + }
  35400. + for (int i = 0; i < srcArrays.size(); i++) {
  35401. + switch (arrDataTypes.get(i)) {
  35402. + case INT32:
  35403. + int[] arrInt = (int[]) srcArrays.get(i);
  35404. + tensorBuffer.loadArray(arrInt, arrShapes.get(i));
  35405. + break;
  35406. + case FLOAT32:
  35407. + float[] arrFloat = (float[]) srcArrays.get(i);
  35408. + tensorBuffer.loadArray(arrFloat, arrShapes.get(i));
  35409. + break;
  35410. + default:
  35411. + break;
  35412. + }
  35413. + }
  35414. + checkResults(tensorBuffer);
  35415. + }
  35416. + }
  35417. +
  35418. + private void checkResults(TensorBuffer tensorBuffer) {
  35419. + ExpectedResults er;
  35420. + switch (tensorBuffer.getDataType()) {
  35421. + case UINT8:
  35422. + er = expectedResForByteBuf;
  35423. + break;
  35424. + case FLOAT32:
  35425. + er = expectedResForFloatBuf;
  35426. + break;
  35427. + default:
  35428. + throw new AssertionError("Cannot test TensorBuffer in the DataType of "
  35429. + + tensorBuffer.getDataType());
  35430. + }
  35431. +
  35432. + // Checks getIntArray() and getFloatArray().
  35433. + int[] resIntArr = tensorBuffer.getIntArray();
  35434. + assertThat(resIntArr).isEqualTo(er.intArr);
  35435. + float[] resFloatArr = tensorBuffer.getFloatArray();
  35436. + assertThat(resFloatArr).isEqualTo(er.floatArr);
  35437. + assertThat(tensorBuffer.getShape()).isEqualTo(er.shape);
  35438. +
  35439. + // Checks getIntValue(int index) and getFloatValue(int index).
  35440. + int flatSize = tensorBuffer.getFlatSize();
  35441. + float[] resFloatValues = new float[flatSize];
  35442. + int[] resIntValues = new int[flatSize];
  35443. + for (int i = 0; i < flatSize; i++) {
  35444. + resFloatValues[i] = tensorBuffer.getFloatValue(i);
  35445. + resIntValues[i] = tensorBuffer.getIntValue(i);
  35446. + }
  35447. + assertThat(resFloatValues).isEqualTo(er.floatArr);
  35448. + assertThat(resIntValues).isEqualTo(er.intArr);
  35449. + }
  35450. +
  35451. + /** Gets the data type of an 1D array. */
  35452. + private static DataType dataTypeOfArray(Object arr) {
  35453. + if (arr != null) {
  35454. + Class<?> c = arr.getClass();
  35455. + if (c.isArray()) {
  35456. + c = c.getComponentType();
  35457. + if (float.class.equals(c)) {
  35458. + return DataType.FLOAT32;
  35459. + } else if (int.class.equals(c)) {
  35460. + return DataType.INT32;
  35461. + } else if (byte.class.equals(c)) {
  35462. + return DataType.UINT8;
  35463. + } else if (long.class.equals(c)) {
  35464. + return DataType.INT64;
  35465. + } else if (String.class.equals(c)) {
  35466. + return DataType.STRING;
  35467. + }
  35468. + }
  35469. + }
  35470. + throw new IllegalArgumentException(
  35471. + "Requires a 1D array. Cannot resolve data type of " + arr.getClass().getName());
  35472. }
  35473. - throw new IllegalArgumentException(
  35474. - "Requires a 1D array. Cannot resolve data type of " + arr.getClass().getName());
  35475. - }
  35476. }
  35477. /** Tests of {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}. */
  35478. @RunWith(RobolectricTestRunner.class)
  35479. public final class TensorBufferTest {
  35480. - // FLOAT_ARRAY1 and INT_ARRAY1 correspond to each other.
  35481. - private static final int[] ARRAY1_SHAPE = new int[] {2, 3};
  35482. - private static final float[] FLOAT_ARRAY1 = new float[] {500.1f, 4.2f, 3.3f, 2.4f, 1.5f, 6.1f};
  35483. - private static final float[] FLOAT_ARRAY1_ROUNDED =
  35484. - new float[] {500.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f};
  35485. - // FLOAT_ARRAY1_CAPPED and INT_ARRAY1_CAPPED correspond to the expected values when converted into
  35486. - // uint8.
  35487. - private static final float[] FLOAT_ARRAY1_CAPPED =
  35488. - new float[] {255.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f};
  35489. - private static final int[] INT_ARRAY1 = new int[] {500, 4, 3, 2, 1, 6};
  35490. - private static final int[] INT_ARRAY1_CAPPED = new int[] {255, 4, 3, 2, 1, 6};
  35491. - // FLOAT_ARRAY2 and INT_ARRAY2 correspond to each other.
  35492. - private static final int[] ARRAY2_SHAPE = new int[] {2, 1};
  35493. - private static final float[] FLOAT_ARRAY2 = new float[] {6.7f, 7.6f};
  35494. - private static final float[] FLOAT_ARRAY2_ROUNDED = new float[] {6.0f, 7.0f};
  35495. - private static final int[] INT_ARRAY2 = new int[] {6, 7};
  35496. - // FLOAT_ARRAY2 and FLOAT_ARRAY3 have the same size.
  35497. - private static final int[] ARRAY3_SHAPE = new int[] {2, 1};
  35498. - private static final float[] FLOAT_ARRAY3 = new float[] {8.2f, 9.9f};
  35499. - private static final float[] FLOAT_ARRAY3_ROUNDED = new float[] {8.0f, 9.0f};
  35500. - // INT_ARRAY2 and INT_ARRAY3 have the same size.
  35501. - private static final int[] INT_ARRAY3 = new int[] {8, 9};
  35502. - private static final int[] EMPTY_ARRAY_SHAPE = new int[] {0};
  35503. - private static final int[] EMPTY_INT_ARRAY = new int[0];
  35504. - private static final float[] EMPTY_FLOAT_ARRAY = new float[0];
  35505. - // Single element array which represents a scalar.
  35506. - private static final int[] SCALAR_ARRAY_SHAPE = new int[] {};
  35507. - private static final float[] FLOAT_SCALAR_ARRAY = new float[] {800.2f};
  35508. - private static final float[] FLOAT_SCALAR_ARRAY_ROUNDED = new float[] {800.0f};
  35509. - private static final float[] FLOAT_SCALAR_ARRAY_CAPPED = new float[] {255.0f};
  35510. - private static final int[] INT_SCALAR_ARRAY = new int[] {800};
  35511. - private static final int[] INT_SCALAR_ARRAY_CAPPED = new int[] {255};
  35512. - // Several different ByteBuffer.
  35513. - private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.allocateDirect(0);
  35514. - private static final ByteBuffer FLOAT_BYTE_BUFFER1 = ByteBuffer.allocateDirect(24);
  35515. -
  35516. - static {
  35517. - FLOAT_BYTE_BUFFER1.rewind();
  35518. -
  35519. - FloatBuffer floatBuffer = FLOAT_BYTE_BUFFER1.asFloatBuffer();
  35520. - floatBuffer.put(FLOAT_ARRAY1);
  35521. - }
  35522. -
  35523. - private static final ByteBuffer INT_BYTE_BUFFER2 = ByteBuffer.allocateDirect(2);
  35524. -
  35525. - static {
  35526. - INT_BYTE_BUFFER2.rewind();
  35527. -
  35528. - for (int a : INT_ARRAY2) {
  35529. - INT_BYTE_BUFFER2.put((byte) a);
  35530. - }
  35531. - }
  35532. -
  35533. - @Test
  35534. - public void testCreateFixedSizeTensorBufferFloat() {
  35535. - int[] shape = new int[] {1, 2, 3};
  35536. - TensorBuffer tensorBufferFloat = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  35537. - assertThat(tensorBufferFloat).isNotNull();
  35538. - assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6);
  35539. - }
  35540. -
  35541. - @Test
  35542. - public void testCreateFixedSizeTensorBufferUint8() {
  35543. - int[] shape = new int[] {1, 2, 3};
  35544. - TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  35545. - assertThat(tensorBufferUint8).isNotNull();
  35546. - assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6);
  35547. - }
  35548. -
  35549. - @Test
  35550. - public void testCreateDynamicTensorBufferFloat() {
  35551. - TensorBuffer tensorBufferFloat = TensorBuffer.createDynamic(DataType.FLOAT32);
  35552. - assertThat(tensorBufferFloat).isNotNull();
  35553. - }
  35554. -
  35555. - @Test
  35556. - public void testCreateDynamicTensorBufferUint8() {
  35557. - TensorBuffer tensorBufferUint8 = TensorBuffer.createDynamic(DataType.UINT8);
  35558. - assertThat(tensorBufferUint8).isNotNull();
  35559. - }
  35560. -
  35561. - @Test
  35562. - public void testCreateTensorBufferFromFixedSize() {
  35563. - int[] shape = new int[] {1, 2, 3};
  35564. - TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  35565. - TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
  35566. - assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3});
  35567. - }
  35568. -
  35569. - @Test
  35570. - public void testCreateTensorBufferFromDynamicSize() {
  35571. - int[] shape = new int[] {1, 2, 3};
  35572. - TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8);
  35573. - src.resize(shape);
  35574. - TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
  35575. - assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3});
  35576. - }
  35577. -
  35578. - @Test
  35579. - public void testCreateTensorBufferUInt8FromUInt8() {
  35580. - int[] shape = new int[] {INT_ARRAY1.length};
  35581. - TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  35582. - src.loadArray(INT_ARRAY1);
  35583. - TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8);
  35584. - int[] data = dst.getIntArray();
  35585. - assertThat(data).isEqualTo(INT_ARRAY1_CAPPED);
  35586. - }
  35587. -
  35588. - @Test
  35589. - public void testCreateTensorBufferUInt8FromFloat32() {
  35590. - TensorBuffer src = TensorBuffer.createDynamic(DataType.FLOAT32);
  35591. - src.loadArray(FLOAT_ARRAY1, ARRAY1_SHAPE);
  35592. - TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8);
  35593. - int[] data = dst.getIntArray();
  35594. - assertThat(data).isEqualTo(INT_ARRAY1_CAPPED);
  35595. - }
  35596. -
  35597. - @Test
  35598. - public void testCreateTensorBufferFloat32FromUInt8() {
  35599. - TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8);
  35600. - src.loadArray(INT_ARRAY1, ARRAY1_SHAPE);
  35601. - TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
  35602. - float[] data = dst.getFloatArray();
  35603. - assertThat(data).isEqualTo(FLOAT_ARRAY1_CAPPED);
  35604. - }
  35605. -
  35606. - @Test
  35607. - public void testCreateTensorBufferFloat32FromFloat32() {
  35608. - int[] shape = new int[] {FLOAT_ARRAY1.length};
  35609. - TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  35610. - src.loadArray(FLOAT_ARRAY1);
  35611. - TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
  35612. - float[] data = dst.getFloatArray();
  35613. - assertThat(data).isEqualTo(FLOAT_ARRAY1);
  35614. - }
  35615. -
  35616. - @Test
  35617. - public void testGetBuffer() throws IOException {
  35618. - int[] shape = new int[] {1, 2, 3};
  35619. - TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  35620. - assertThat(tensorBufferUint8.getBuffer()).isNotNull();
  35621. - }
  35622. -
  35623. - @Test
  35624. - public void testLoadAndGetIntArrayWithFixedSizeForScalarArray() throws IOException {
  35625. - ArrayTestRunner.Builder.newInstance()
  35626. - .addSrcArray(INT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE)
  35627. - .setTensorBufferShape(SCALAR_ARRAY_SHAPE)
  35628. - .setExpectedResults(
  35629. - /*bufferType = */ DataType.FLOAT32,
  35630. - /*expectedFloatArr=*/ FLOAT_SCALAR_ARRAY_ROUNDED,
  35631. - /*expectedIntArr=*/ INT_SCALAR_ARRAY)
  35632. - .setExpectedResults(
  35633. - /*bufferType = */ DataType.UINT8,
  35634. - /*expectedFloatArr=*/ FLOAT_SCALAR_ARRAY_CAPPED,
  35635. - /*expectedIntArr=*/ INT_SCALAR_ARRAY_CAPPED)
  35636. - .build()
  35637. - .run();
  35638. - }
  35639. -
  35640. - @Test
  35641. - public void testLoadAndGetFloatArrayWithFixedSizeForScalarArray() throws IOException {
  35642. - ArrayTestRunner.Builder.newInstance()
  35643. - .addSrcArray(FLOAT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE)
  35644. - .setTensorBufferShape(SCALAR_ARRAY_SHAPE)
  35645. - .setExpectedResults(
  35646. - /*bufferType = */ DataType.FLOAT32,
  35647. - /*expectedFloatArr=*/ FLOAT_SCALAR_ARRAY,
  35648. - /*expectedIntArr=*/ INT_SCALAR_ARRAY)
  35649. - .setExpectedResults(
  35650. - /*bufferType = */ DataType.UINT8,
  35651. - /*expectedFloatArr=*/ FLOAT_SCALAR_ARRAY_CAPPED,
  35652. - /*expectedIntArr=*/ INT_SCALAR_ARRAY_CAPPED)
  35653. - .build()
  35654. - .run();
  35655. - }
  35656. -
  35657. - @Test
  35658. - public void testLoadAndGetIntArrayWithFixedSize() {
  35659. - ArrayTestRunner.Builder.newInstance()
  35660. - .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
  35661. - .setTensorBufferShape(ARRAY1_SHAPE)
  35662. - .setExpectedResults(
  35663. - /*bufferType = */ DataType.FLOAT32,
  35664. - /*expectedFloatArr=*/ FLOAT_ARRAY1_ROUNDED,
  35665. - /*expectedIntArr=*/ INT_ARRAY1)
  35666. - .setExpectedResults(
  35667. - /*bufferType = */ DataType.UINT8,
  35668. - /*expectedFloatArr=*/ FLOAT_ARRAY1_CAPPED,
  35669. - /*expectedIntArr=*/ INT_ARRAY1_CAPPED)
  35670. - .build()
  35671. - .run();
  35672. - }
  35673. -
  35674. - @Test
  35675. - public void testLoadAndGetFloatArrayWithFixedSize() {
  35676. - ArrayTestRunner.Builder.newInstance()
  35677. - .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
  35678. - .setTensorBufferShape(ARRAY1_SHAPE)
  35679. - .setExpectedResults(
  35680. - /*bufferType = */ DataType.FLOAT32,
  35681. - /*expectedFloatArr=*/ FLOAT_ARRAY1,
  35682. - /*expectedIntArr=*/ INT_ARRAY1)
  35683. - .setExpectedResults(
  35684. - /*bufferType = */ DataType.UINT8,
  35685. - /*expectedFloatArr=*/ FLOAT_ARRAY1_CAPPED,
  35686. - /*expectedIntArr=*/ INT_ARRAY1_CAPPED)
  35687. - .build()
  35688. - .run();
  35689. - }
  35690. -
  35691. - @Test
  35692. - public void testRepeatedLoadAndGetIntArrayWithSameFixedSize() {
  35693. - ArrayTestRunner.Builder.newInstance()
  35694. - .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE)
  35695. - .addSrcArray(INT_ARRAY3, ARRAY3_SHAPE)
  35696. - .setTensorBufferShape(ARRAY2_SHAPE)
  35697. - .setExpectedResults(
  35698. - /*bufferType = */ DataType.FLOAT32,
  35699. - /*expectedFloatArr=*/ FLOAT_ARRAY3_ROUNDED,
  35700. - /*expectedIntArr=*/ INT_ARRAY3)
  35701. - .setExpectedResults(
  35702. - /*bufferType = */ DataType.UINT8,
  35703. - /*expectedFloatArr=*/ FLOAT_ARRAY3_ROUNDED,
  35704. - /*expectedIntArr=*/ INT_ARRAY3)
  35705. - .build()
  35706. - .run();
  35707. - }
  35708. -
  35709. - @Test
  35710. - public void testRepeatedLoadAndGetFloatArrayWithSameFixedSize() {
  35711. - ArrayTestRunner.Builder.newInstance()
  35712. - .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
  35713. - .addSrcArray(FLOAT_ARRAY3, ARRAY3_SHAPE)
  35714. - .setTensorBufferShape(ARRAY2_SHAPE)
  35715. - .setExpectedResults(
  35716. - /*bufferType = */ DataType.FLOAT32,
  35717. - /*expectedFloatArr=*/ FLOAT_ARRAY3,
  35718. - /*expectedIntArr=*/ INT_ARRAY3)
  35719. - .setExpectedResults(
  35720. - /*bufferType = */ DataType.UINT8,
  35721. - /*expectedFloatArr=*/ FLOAT_ARRAY3_ROUNDED,
  35722. - /*expectedIntArr=*/ INT_ARRAY3)
  35723. - .build()
  35724. - .run();
  35725. - }
  35726. -
  35727. - @Test
  35728. - public void testRepeatedLoadIntArrayWithDifferentFixedSize() {
  35729. - int[] srcArr1 = INT_ARRAY1;
  35730. - int[] srcArr2 = INT_ARRAY2;
  35731. - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  35732. - TensorBuffer tensorBuffer =
  35733. - TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType);
  35734. - tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length});
  35735. - // Load srcArr2 which had different size as srcArr1.
  35736. - Assert.assertThrows(
  35737. - IllegalArgumentException.class,
  35738. - () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length}));
  35739. - }
  35740. - }
  35741. -
  35742. - @Test
  35743. - public void testRepeatedLoadFloatArrayWithDifferentFixedSize() {
  35744. - float[] srcArr1 = FLOAT_ARRAY1;
  35745. - float[] srcArr2 = FLOAT_ARRAY2;
  35746. - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  35747. - TensorBuffer tensorBuffer =
  35748. - TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType);
  35749. - tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length});
  35750. - // Load srcArr2 which had different size as srcArr1.
  35751. - Assert.assertThrows(
  35752. - IllegalArgumentException.class,
  35753. - () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length}));
  35754. - }
  35755. - }
  35756. -
  35757. - @Test
  35758. - public void testLoadAndGetIntArrayWithDynamicSize() {
  35759. - ArrayTestRunner.Builder.newInstance()
  35760. - .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
  35761. - .setExpectedResults(
  35762. - /*bufferType = */ DataType.FLOAT32,
  35763. - /*expectedFloatArr=*/ FLOAT_ARRAY1_ROUNDED,
  35764. - /*expectedIntArr=*/ INT_ARRAY1)
  35765. - .setExpectedResults(
  35766. - /*bufferType = */ DataType.UINT8,
  35767. - /*expectedFloatArr=*/ FLOAT_ARRAY1_CAPPED,
  35768. - /*expectedIntArr=*/ INT_ARRAY1_CAPPED)
  35769. - .build()
  35770. - .run();
  35771. - }
  35772. -
  35773. - @Test
  35774. - public void testLoadAndGetFloatArrayWithDynamicSize() {
  35775. - ArrayTestRunner.Builder.newInstance()
  35776. - .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
  35777. - .setExpectedResults(
  35778. - /*bufferType = */ DataType.FLOAT32,
  35779. - /*expectedFloatArr=*/ FLOAT_ARRAY1,
  35780. - /*expectedIntArr=*/ INT_ARRAY1)
  35781. - .setExpectedResults(
  35782. - /*bufferType = */ DataType.UINT8,
  35783. - /*expectedFloatArr=*/ FLOAT_ARRAY1_CAPPED,
  35784. - /*expectedIntArr=*/ INT_ARRAY1_CAPPED)
  35785. - .build()
  35786. - .run();
  35787. - }
  35788. -
  35789. - @Test
  35790. - public void testRepeatedLoadAndGetIntArrayWithDifferentDynamicSize() {
  35791. - ArrayTestRunner.Builder.newInstance()
  35792. - .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
  35793. - .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE)
  35794. - .setExpectedResults(
  35795. - /*bufferType = */ DataType.FLOAT32,
  35796. - /*expectedFloatArr=*/ FLOAT_ARRAY2_ROUNDED,
  35797. - /*expectedIntArr=*/ INT_ARRAY2)
  35798. - .setExpectedResults(
  35799. - /*bufferType = */ DataType.UINT8,
  35800. - /*expectedFloatArr=*/ FLOAT_ARRAY2_ROUNDED,
  35801. - /*expectedIntArr=*/ INT_ARRAY2)
  35802. - .build()
  35803. - .run();
  35804. - }
  35805. -
  35806. - @Test
  35807. - public void testRepeatedLoadAndGetFloatArrayWithDifferentDynamicSize() {
  35808. - ArrayTestRunner.Builder.newInstance()
  35809. - .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
  35810. - .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
  35811. - .setExpectedResults(
  35812. - /*bufferType = */ DataType.FLOAT32,
  35813. - /*expectedFloatArr=*/ FLOAT_ARRAY2,
  35814. - /*expectedIntArr=*/ INT_ARRAY2)
  35815. - .setExpectedResults(
  35816. - /*bufferType = */ DataType.UINT8,
  35817. - /*expectedFloatArr=*/ FLOAT_ARRAY2_ROUNDED,
  35818. - /*expectedIntArr=*/ INT_ARRAY2)
  35819. - .build()
  35820. - .run();
  35821. - }
  35822. -
  35823. - @Test
  35824. - public void testGetForEmptyArrayWithFixedSizeBuffer() {
  35825. - ArrayTestRunner.Builder.newInstance()
  35826. - .setTensorBufferShape(EMPTY_ARRAY_SHAPE)
  35827. - .setExpectedResults(
  35828. - /*bufferType = */ DataType.FLOAT32,
  35829. - /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
  35830. - /*expectedIntArr=*/ EMPTY_INT_ARRAY)
  35831. - .setExpectedResults(
  35832. - /*bufferType = */ DataType.UINT8,
  35833. - /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
  35834. - /*expectedIntArr=*/ EMPTY_INT_ARRAY)
  35835. - .build()
  35836. - .run();
  35837. - }
  35838. -
  35839. - @Test
  35840. - public void testGetForEmptyArrayWithDynamicBuffer() {
  35841. - ArrayTestRunner.Builder.newInstance()
  35842. - .setExpectedResults(
  35843. - /*bufferType = */ DataType.FLOAT32,
  35844. - /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
  35845. - /*expectedIntArr=*/ EMPTY_INT_ARRAY)
  35846. - .setExpectedResults(
  35847. - /*bufferType = */ DataType.UINT8,
  35848. - /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
  35849. - /*expectedIntArr=*/ EMPTY_INT_ARRAY)
  35850. - .build()
  35851. - .run();
  35852. - }
  35853. -
  35854. - @Test
  35855. - public void testRepeatedLoadAndGetForEmptyArray() {
  35856. - ArrayTestRunner.Builder.newInstance()
  35857. - .addSrcArray(EMPTY_INT_ARRAY, EMPTY_ARRAY_SHAPE)
  35858. - .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
  35859. - .addSrcArray(EMPTY_FLOAT_ARRAY, EMPTY_ARRAY_SHAPE)
  35860. - .setExpectedResults(
  35861. - /*bufferType = */ DataType.FLOAT32,
  35862. - /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
  35863. - /*expectedIntArr=*/ EMPTY_INT_ARRAY)
  35864. - .setExpectedResults(
  35865. - /*bufferType = */ DataType.UINT8,
  35866. - /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
  35867. - /*expectedIntArr=*/ EMPTY_INT_ARRAY)
  35868. - .build()
  35869. - .run();
  35870. - }
  35871. -
  35872. - @Test
  35873. - public void testLoadNullIntArrays() {
  35874. - int[] nullArray = null;
  35875. - int[] shape = new int[] {};
  35876. - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  35877. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  35878. - Assert.assertThrows(
  35879. - NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape));
  35880. - }
  35881. - }
  35882. -
  35883. - @Test
  35884. - public void testLoadNullFloatArrays() {
  35885. - float[] nullArray = null;
  35886. - int[] shape = new int[] {};
  35887. - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  35888. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  35889. - Assert.assertThrows(
  35890. - NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape));
  35891. - }
  35892. - }
  35893. -
  35894. - @Test
  35895. - public void testLoadFloatArraysWithNullShape() {
  35896. - float[] arr = new float[] {1.0f};
  35897. - int[] nullShape = null;
  35898. - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  35899. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  35900. - Assert.assertThrows(NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape));
  35901. - }
  35902. - }
  35903. -
  35904. - @Test
  35905. - public void testLoadIntArraysWithNullShape() {
  35906. - int[] arr = new int[] {1};
  35907. - int[] nullShape = null;
  35908. - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  35909. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  35910. - Assert.assertThrows(NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape));
  35911. - }
  35912. - }
  35913. -
  35914. - @Test
  35915. - public void testLoadIntArraysWithoutShapeAndArrayDoesNotMatchShape() {
  35916. - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  35917. - TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType);
  35918. - Assert.assertThrows(
  35919. - IllegalArgumentException.class, () -> fixedTensorBuffer.loadArray(INT_ARRAY2));
  35920. - }
  35921. - }
  35922. -
  35923. - @Test
  35924. - public void testLoadFloatArraysWithoutShapeAndArrayDoesNotMatchShape() {
  35925. - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  35926. - TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType);
  35927. - Assert.assertThrows(
  35928. - IllegalArgumentException.class, () -> fixedTensorBuffer.loadArray(FLOAT_ARRAY2));
  35929. - }
  35930. - }
  35931. -
  35932. - @Test
  35933. - public void testLoadByteBufferForNullBuffer() {
  35934. - ByteBuffer byteBuffer = null;
  35935. - int[] shape = new int[] {};
  35936. - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  35937. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  35938. - Assert.assertThrows(
  35939. - NullPointerException.class, () -> tensorBuffer.loadBuffer(byteBuffer, shape));
  35940. - }
  35941. - }
  35942. -
  35943. - @Test
  35944. - public void testLoadByteBufferForEmptyBuffer() {
  35945. - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  35946. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  35947. - tensorBuffer.loadBuffer(EMPTY_BYTE_BUFFER, EMPTY_ARRAY_SHAPE);
  35948. - assertThat(tensorBuffer.getFlatSize()).isEqualTo(0);
  35949. - }
  35950. - }
  35951. -
  35952. - @Test
  35953. - public void testLoadByteBufferWithDifferentFixedSize() {
  35954. - // Create a fixed-size TensorBuffer with size 2, and load a ByteBuffer with size 5.
  35955. - int[] tensorBufferShape = new int[] {2};
  35956. - TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(tensorBufferShape, DataType.FLOAT32);
  35957. - Assert.assertThrows(
  35958. - IllegalArgumentException.class,
  35959. - () -> tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE));
  35960. - }
  35961. -
  35962. - @Test
  35963. - public void testLoadByteBufferWithMisMatchDataType() {
  35964. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  35965. - int[] wrongShape = new int[] {1};
  35966. - // Size of INT_BYTE_BUFFER is 8 bytes. It does not match the specified shape.
  35967. - Assert.assertThrows(
  35968. - IllegalArgumentException.class,
  35969. - () -> tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, wrongShape));
  35970. - }
  35971. -
  35972. - @Test
  35973. - public void testLoadByteBufferForTensorBufferFloat() {
  35974. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  35975. - tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE);
  35976. - assertThat(tensorBuffer.getFloatArray()).isEqualTo(FLOAT_ARRAY1);
  35977. - assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY1_SHAPE);
  35978. - }
  35979. -
  35980. - @Test
  35981. - public void testLoadByteBufferForTensorBufferUint8() {
  35982. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
  35983. - tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, ARRAY2_SHAPE);
  35984. - assertThat(tensorBuffer.getIntArray()).isEqualTo(INT_ARRAY2);
  35985. - assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY2_SHAPE);
  35986. - }
  35987. -
  35988. - @Test
  35989. - public void testGetFloatValueWithInvalidIndex() {
  35990. - float[] arrayWithSixElements = FLOAT_ARRAY1;
  35991. - int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE;
  35992. - int[] invalidIndexes = {-1, 7};
  35993. - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  35994. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  35995. - tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements);
  35996. - for (int invalidIndex : invalidIndexes) {
  35997. - Assert.assertThrows(
  35998. - IndexOutOfBoundsException.class, () -> tensorBuffer.getFloatValue(invalidIndex));
  35999. - }
  36000. - }
  36001. - }
  36002. -
  36003. - @Test
  36004. - public void testGetFloatValueFromScalarWithInvalidIndex() {
  36005. - int[] shape = new int[] {};
  36006. - float[] arr = new float[] {10.0f};
  36007. - int[] invalidIndexes =
  36008. - new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize.
  36009. - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36010. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  36011. - tensorBuffer.loadArray(arr, shape);
  36012. - for (int invalidIndex : invalidIndexes) {
  36013. - Assert.assertThrows(
  36014. - IndexOutOfBoundsException.class, () -> tensorBuffer.getFloatValue(invalidIndex));
  36015. - }
  36016. - }
  36017. - }
  36018. -
  36019. - @Test
  36020. - public void testGetIntValueWithInvalidIndex() {
  36021. - float[] arrayWithSixElements = FLOAT_ARRAY1;
  36022. - int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE;
  36023. - int[] invalidIndexes = {-1, 7};
  36024. - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36025. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  36026. - tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements);
  36027. - for (int invalidIndex : invalidIndexes) {
  36028. - Assert.assertThrows(
  36029. - IndexOutOfBoundsException.class, () -> tensorBuffer.getIntValue(invalidIndex));
  36030. - }
  36031. - }
  36032. - }
  36033. -
  36034. - @Test
  36035. - public void testGetIntValueFromScalarWithInvalidIndex() {
  36036. - int[] shape = new int[] {};
  36037. - float[] arr = new float[] {10.0f};
  36038. - int[] invalidIndexes =
  36039. - new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize.
  36040. - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36041. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  36042. - tensorBuffer.loadArray(arr, shape);
  36043. - for (int invalidIndex : invalidIndexes) {
  36044. - Assert.assertThrows(
  36045. - IndexOutOfBoundsException.class, () -> tensorBuffer.getIntValue(invalidIndex));
  36046. - }
  36047. - }
  36048. - }
  36049. -
  36050. - @Test
  36051. - public void testLoadByteBufferSliceForTensorBufferFloat() {
  36052. - TensorBuffer original = TensorBuffer.createDynamic(DataType.FLOAT32);
  36053. - original.loadArray(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, new int[] {6});
  36054. - ByteBuffer buffer = original.getBuffer();
  36055. - // Slice original buffer to 3 sub-buffer, each of which has 2 element
  36056. - int numBuffers = 3;
  36057. - int numElements = 2;
  36058. - int subArrayLength = numElements * original.getTypeSize();
  36059. - TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType());
  36060. - for (int i = 0; i < numBuffers; i++) {
  36061. - buffer.position(i * subArrayLength);
  36062. - ByteBuffer subBuffer = buffer.slice();
  36063. - // ByteBuffer.slice doesn't keep order.
  36064. - subBuffer.order(buffer.order()).limit(subArrayLength);
  36065. - tensorSlice.loadBuffer(subBuffer, new int[] {numElements});
  36066. - float[] arraySlice = tensorSlice.getFloatArray();
  36067. - assertThat(arraySlice.length).isEqualTo(numElements);
  36068. - assertThat(arraySlice[0]).isEqualTo(i * numElements + 1);
  36069. - assertThat(arraySlice[1]).isEqualTo(i * numElements + 2);
  36070. - }
  36071. - }
  36072. -
  36073. - @Test
  36074. - public void testLoadByteBufferSliceForTensorBufferUInt8() {
  36075. - TensorBuffer original = TensorBuffer.createDynamic(DataType.UINT8);
  36076. - original.loadArray(new int[] {1, 2, 3, 4, 5, 6}, new int[] {6});
  36077. - ByteBuffer buffer = original.getBuffer();
  36078. - // Slice original buffer to 3 sub-buffer, each of which has 2 element
  36079. - int numBuffers = 3;
  36080. - int numElements = 2;
  36081. - int subArrayLength = numElements * original.getTypeSize();
  36082. - TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType());
  36083. - for (int i = 0; i < numBuffers; i++) {
  36084. - buffer.position(i * subArrayLength);
  36085. - ByteBuffer subBuffer = buffer.slice();
  36086. - // ByteBuffer.slice doesn't keep order.
  36087. - subBuffer.order(buffer.order()).limit(subArrayLength);
  36088. - tensorSlice.loadBuffer(subBuffer, new int[] {numElements});
  36089. - int[] arraySlice = tensorSlice.getIntArray();
  36090. - assertThat(arraySlice.length).isEqualTo(numElements);
  36091. - assertThat(arraySlice[0]).isEqualTo(i * numElements + 1);
  36092. - assertThat(arraySlice[1]).isEqualTo(i * numElements + 2);
  36093. - }
  36094. - }
  36095. -
  36096. - @Test
  36097. - public void getShapeFailsAfterByteBufferChanged() {
  36098. - TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32);
  36099. - ByteBuffer byteBuffer = tensorBuffer.getBuffer();
  36100. - byteBuffer.limit(5);
  36101. -
  36102. - IllegalStateException exception =
  36103. - assertThrows(IllegalStateException.class, tensorBuffer::getShape);
  36104. - assertThat(exception)
  36105. - .hasMessageThat()
  36106. - .contains(
  36107. - "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The"
  36108. + // FLOAT_ARRAY1 and INT_ARRAY1 correspond to each other.
  36109. + private static final int[] ARRAY1_SHAPE = new int[] {2, 3};
  36110. + private static final float[] FLOAT_ARRAY1 = new float[] {500.1f, 4.2f, 3.3f, 2.4f, 1.5f, 6.1f};
  36111. + private static final float[] FLOAT_ARRAY1_ROUNDED =
  36112. + new float[] {500.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f};
  36113. + // FLOAT_ARRAY1_CAPPED and INT_ARRAY1_CAPPED correspond to the expected values when converted
  36114. + // into uint8.
  36115. + private static final float[] FLOAT_ARRAY1_CAPPED =
  36116. + new float[] {255.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f};
  36117. + private static final int[] INT_ARRAY1 = new int[] {500, 4, 3, 2, 1, 6};
  36118. + private static final int[] INT_ARRAY1_CAPPED = new int[] {255, 4, 3, 2, 1, 6};
  36119. + // FLOAT_ARRAY2 and INT_ARRAY2 correspond to each other.
  36120. + private static final int[] ARRAY2_SHAPE = new int[] {2, 1};
  36121. + private static final float[] FLOAT_ARRAY2 = new float[] {6.7f, 7.6f};
  36122. + private static final float[] FLOAT_ARRAY2_ROUNDED = new float[] {6.0f, 7.0f};
  36123. + private static final int[] INT_ARRAY2 = new int[] {6, 7};
  36124. + // FLOAT_ARRAY2 and FLOAT_ARRAY3 have the same size.
  36125. + private static final int[] ARRAY3_SHAPE = new int[] {2, 1};
  36126. + private static final float[] FLOAT_ARRAY3 = new float[] {8.2f, 9.9f};
  36127. + private static final float[] FLOAT_ARRAY3_ROUNDED = new float[] {8.0f, 9.0f};
  36128. + // INT_ARRAY2 and INT_ARRAY3 have the same size.
  36129. + private static final int[] INT_ARRAY3 = new int[] {8, 9};
  36130. + private static final int[] EMPTY_ARRAY_SHAPE = new int[] {0};
  36131. + private static final int[] EMPTY_INT_ARRAY = new int[0];
  36132. + private static final float[] EMPTY_FLOAT_ARRAY = new float[0];
  36133. + // Single element array which represents a scalar.
  36134. + private static final int[] SCALAR_ARRAY_SHAPE = new int[] {};
  36135. + private static final float[] FLOAT_SCALAR_ARRAY = new float[] {800.2f};
  36136. + private static final float[] FLOAT_SCALAR_ARRAY_ROUNDED = new float[] {800.0f};
  36137. + private static final float[] FLOAT_SCALAR_ARRAY_CAPPED = new float[] {255.0f};
  36138. + private static final int[] INT_SCALAR_ARRAY = new int[] {800};
  36139. + private static final int[] INT_SCALAR_ARRAY_CAPPED = new int[] {255};
  36140. + // Several different ByteBuffer.
  36141. + private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.allocateDirect(0);
  36142. + private static final ByteBuffer FLOAT_BYTE_BUFFER1 = ByteBuffer.allocateDirect(24);
  36143. +
  36144. + static {
  36145. + FLOAT_BYTE_BUFFER1.rewind();
  36146. +
  36147. + FloatBuffer floatBuffer = FLOAT_BYTE_BUFFER1.asFloatBuffer();
  36148. + floatBuffer.put(FLOAT_ARRAY1);
  36149. + }
  36150. +
  36151. + private static final ByteBuffer INT_BYTE_BUFFER2 = ByteBuffer.allocateDirect(2);
  36152. +
  36153. + static {
  36154. + INT_BYTE_BUFFER2.rewind();
  36155. +
  36156. + for (int a : INT_ARRAY2) {
  36157. + INT_BYTE_BUFFER2.put((byte) a);
  36158. + }
  36159. + }
  36160. +
  36161. + @Test
  36162. + public void testCreateFixedSizeTensorBufferFloat() {
  36163. + int[] shape = new int[] {1, 2, 3};
  36164. + TensorBuffer tensorBufferFloat = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  36165. + assertThat(tensorBufferFloat).isNotNull();
  36166. + assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6);
  36167. + }
  36168. +
  36169. + @Test
  36170. + public void testCreateFixedSizeTensorBufferUint8() {
  36171. + int[] shape = new int[] {1, 2, 3};
  36172. + TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  36173. + assertThat(tensorBufferUint8).isNotNull();
  36174. + assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6);
  36175. + }
  36176. +
  36177. + @Test
  36178. + public void testCreateDynamicTensorBufferFloat() {
  36179. + TensorBuffer tensorBufferFloat = TensorBuffer.createDynamic(DataType.FLOAT32);
  36180. + assertThat(tensorBufferFloat).isNotNull();
  36181. + }
  36182. +
  36183. + @Test
  36184. + public void testCreateDynamicTensorBufferUint8() {
  36185. + TensorBuffer tensorBufferUint8 = TensorBuffer.createDynamic(DataType.UINT8);
  36186. + assertThat(tensorBufferUint8).isNotNull();
  36187. + }
  36188. +
  36189. + @Test
  36190. + public void testCreateTensorBufferFromFixedSize() {
  36191. + int[] shape = new int[] {1, 2, 3};
  36192. + TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  36193. + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
  36194. + assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3});
  36195. + }
  36196. +
  36197. + @Test
  36198. + public void testCreateTensorBufferFromDynamicSize() {
  36199. + int[] shape = new int[] {1, 2, 3};
  36200. + TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8);
  36201. + src.resize(shape);
  36202. + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
  36203. + assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3});
  36204. + }
  36205. +
  36206. + @Test
  36207. + public void testCreateTensorBufferUInt8FromUInt8() {
  36208. + int[] shape = new int[] {INT_ARRAY1.length};
  36209. + TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  36210. + src.loadArray(INT_ARRAY1);
  36211. + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8);
  36212. + int[] data = dst.getIntArray();
  36213. + assertThat(data).isEqualTo(INT_ARRAY1_CAPPED);
  36214. + }
  36215. +
  36216. + @Test
  36217. + public void testCreateTensorBufferUInt8FromFloat32() {
  36218. + TensorBuffer src = TensorBuffer.createDynamic(DataType.FLOAT32);
  36219. + src.loadArray(FLOAT_ARRAY1, ARRAY1_SHAPE);
  36220. + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8);
  36221. + int[] data = dst.getIntArray();
  36222. + assertThat(data).isEqualTo(INT_ARRAY1_CAPPED);
  36223. + }
  36224. +
  36225. + @Test
  36226. + public void testCreateTensorBufferFloat32FromUInt8() {
  36227. + TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8);
  36228. + src.loadArray(INT_ARRAY1, ARRAY1_SHAPE);
  36229. + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
  36230. + float[] data = dst.getFloatArray();
  36231. + assertThat(data).isEqualTo(FLOAT_ARRAY1_CAPPED);
  36232. + }
  36233. +
  36234. + @Test
  36235. + public void testCreateTensorBufferFloat32FromFloat32() {
  36236. + int[] shape = new int[] {FLOAT_ARRAY1.length};
  36237. + TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  36238. + src.loadArray(FLOAT_ARRAY1);
  36239. + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
  36240. + float[] data = dst.getFloatArray();
  36241. + assertThat(data).isEqualTo(FLOAT_ARRAY1);
  36242. + }
  36243. +
  36244. + @Test
  36245. + public void testGetBuffer() throws IOException {
  36246. + int[] shape = new int[] {1, 2, 3};
  36247. + TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8);
  36248. + assertThat(tensorBufferUint8.getBuffer()).isNotNull();
  36249. + }
  36250. +
  36251. + @Test
  36252. + public void testLoadAndGetIntArrayWithFixedSizeForScalarArray() throws IOException {
  36253. + ArrayTestRunner.Builder.newInstance()
  36254. + .addSrcArray(INT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE)
  36255. + .setTensorBufferShape(SCALAR_ARRAY_SHAPE)
  36256. + .setExpectedResults(
  36257. + /*bufferType = */ DataType.FLOAT32,
  36258. + /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY_ROUNDED,
  36259. + /*expectedIntArr=*/INT_SCALAR_ARRAY)
  36260. + .setExpectedResults(
  36261. + /*bufferType = */ DataType.UINT8,
  36262. + /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY_CAPPED,
  36263. + /*expectedIntArr=*/INT_SCALAR_ARRAY_CAPPED)
  36264. + .build()
  36265. + .run();
  36266. + }
  36267. +
  36268. + @Test
  36269. + public void testLoadAndGetFloatArrayWithFixedSizeForScalarArray() throws IOException {
  36270. + ArrayTestRunner.Builder.newInstance()
  36271. + .addSrcArray(FLOAT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE)
  36272. + .setTensorBufferShape(SCALAR_ARRAY_SHAPE)
  36273. + .setExpectedResults(
  36274. + /*bufferType = */ DataType.FLOAT32,
  36275. + /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY,
  36276. + /*expectedIntArr=*/INT_SCALAR_ARRAY)
  36277. + .setExpectedResults(
  36278. + /*bufferType = */ DataType.UINT8,
  36279. + /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY_CAPPED,
  36280. + /*expectedIntArr=*/INT_SCALAR_ARRAY_CAPPED)
  36281. + .build()
  36282. + .run();
  36283. + }
  36284. +
  36285. + @Test
  36286. + public void testLoadAndGetIntArrayWithFixedSize() {
  36287. + ArrayTestRunner.Builder.newInstance()
  36288. + .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
  36289. + .setTensorBufferShape(ARRAY1_SHAPE)
  36290. + .setExpectedResults(
  36291. + /*bufferType = */ DataType.FLOAT32,
  36292. + /*expectedFloatArr=*/FLOAT_ARRAY1_ROUNDED,
  36293. + /*expectedIntArr=*/INT_ARRAY1)
  36294. + .setExpectedResults(
  36295. + /*bufferType = */ DataType.UINT8,
  36296. + /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED,
  36297. + /*expectedIntArr=*/INT_ARRAY1_CAPPED)
  36298. + .build()
  36299. + .run();
  36300. + }
  36301. +
  36302. + @Test
  36303. + public void testLoadAndGetFloatArrayWithFixedSize() {
  36304. + ArrayTestRunner.Builder.newInstance()
  36305. + .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
  36306. + .setTensorBufferShape(ARRAY1_SHAPE)
  36307. + .setExpectedResults(
  36308. + /*bufferType = */ DataType.FLOAT32,
  36309. + /*expectedFloatArr=*/FLOAT_ARRAY1,
  36310. + /*expectedIntArr=*/INT_ARRAY1)
  36311. + .setExpectedResults(
  36312. + /*bufferType = */ DataType.UINT8,
  36313. + /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED,
  36314. + /*expectedIntArr=*/INT_ARRAY1_CAPPED)
  36315. + .build()
  36316. + .run();
  36317. + }
  36318. +
  36319. + @Test
  36320. + public void testRepeatedLoadAndGetIntArrayWithSameFixedSize() {
  36321. + ArrayTestRunner.Builder.newInstance()
  36322. + .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE)
  36323. + .addSrcArray(INT_ARRAY3, ARRAY3_SHAPE)
  36324. + .setTensorBufferShape(ARRAY2_SHAPE)
  36325. + .setExpectedResults(
  36326. + /*bufferType = */ DataType.FLOAT32,
  36327. + /*expectedFloatArr=*/FLOAT_ARRAY3_ROUNDED,
  36328. + /*expectedIntArr=*/INT_ARRAY3)
  36329. + .setExpectedResults(
  36330. + /*bufferType = */ DataType.UINT8,
  36331. + /*expectedFloatArr=*/FLOAT_ARRAY3_ROUNDED,
  36332. + /*expectedIntArr=*/INT_ARRAY3)
  36333. + .build()
  36334. + .run();
  36335. + }
  36336. +
  36337. + @Test
  36338. + public void testRepeatedLoadAndGetFloatArrayWithSameFixedSize() {
  36339. + ArrayTestRunner.Builder.newInstance()
  36340. + .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
  36341. + .addSrcArray(FLOAT_ARRAY3, ARRAY3_SHAPE)
  36342. + .setTensorBufferShape(ARRAY2_SHAPE)
  36343. + .setExpectedResults(
  36344. + /*bufferType = */ DataType.FLOAT32,
  36345. + /*expectedFloatArr=*/FLOAT_ARRAY3,
  36346. + /*expectedIntArr=*/INT_ARRAY3)
  36347. + .setExpectedResults(
  36348. + /*bufferType = */ DataType.UINT8,
  36349. + /*expectedFloatArr=*/FLOAT_ARRAY3_ROUNDED,
  36350. + /*expectedIntArr=*/INT_ARRAY3)
  36351. + .build()
  36352. + .run();
  36353. + }
  36354. +
  36355. + @Test
  36356. + public void testRepeatedLoadIntArrayWithDifferentFixedSize() {
  36357. + int[] srcArr1 = INT_ARRAY1;
  36358. + int[] srcArr2 = INT_ARRAY2;
  36359. + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36360. + TensorBuffer tensorBuffer =
  36361. + TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType);
  36362. + tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length});
  36363. + // Load srcArr2 which had different size as srcArr1.
  36364. + Assert.assertThrows(IllegalArgumentException.class,
  36365. + () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length}));
  36366. + }
  36367. + }
  36368. +
  36369. + @Test
  36370. + public void testRepeatedLoadFloatArrayWithDifferentFixedSize() {
  36371. + float[] srcArr1 = FLOAT_ARRAY1;
  36372. + float[] srcArr2 = FLOAT_ARRAY2;
  36373. + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36374. + TensorBuffer tensorBuffer =
  36375. + TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType);
  36376. + tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length});
  36377. + // Load srcArr2 which had different size as srcArr1.
  36378. + Assert.assertThrows(IllegalArgumentException.class,
  36379. + () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length}));
  36380. + }
  36381. + }
  36382. +
  36383. + @Test
  36384. + public void testLoadAndGetIntArrayWithDynamicSize() {
  36385. + ArrayTestRunner.Builder.newInstance()
  36386. + .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
  36387. + .setExpectedResults(
  36388. + /*bufferType = */ DataType.FLOAT32,
  36389. + /*expectedFloatArr=*/FLOAT_ARRAY1_ROUNDED,
  36390. + /*expectedIntArr=*/INT_ARRAY1)
  36391. + .setExpectedResults(
  36392. + /*bufferType = */ DataType.UINT8,
  36393. + /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED,
  36394. + /*expectedIntArr=*/INT_ARRAY1_CAPPED)
  36395. + .build()
  36396. + .run();
  36397. + }
  36398. +
  36399. + @Test
  36400. + public void testLoadAndGetFloatArrayWithDynamicSize() {
  36401. + ArrayTestRunner.Builder.newInstance()
  36402. + .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
  36403. + .setExpectedResults(
  36404. + /*bufferType = */ DataType.FLOAT32,
  36405. + /*expectedFloatArr=*/FLOAT_ARRAY1,
  36406. + /*expectedIntArr=*/INT_ARRAY1)
  36407. + .setExpectedResults(
  36408. + /*bufferType = */ DataType.UINT8,
  36409. + /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED,
  36410. + /*expectedIntArr=*/INT_ARRAY1_CAPPED)
  36411. + .build()
  36412. + .run();
  36413. + }
  36414. +
  36415. + @Test
  36416. + public void testRepeatedLoadAndGetIntArrayWithDifferentDynamicSize() {
  36417. + ArrayTestRunner.Builder.newInstance()
  36418. + .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
  36419. + .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE)
  36420. + .setExpectedResults(
  36421. + /*bufferType = */ DataType.FLOAT32,
  36422. + /*expectedFloatArr=*/FLOAT_ARRAY2_ROUNDED,
  36423. + /*expectedIntArr=*/INT_ARRAY2)
  36424. + .setExpectedResults(
  36425. + /*bufferType = */ DataType.UINT8,
  36426. + /*expectedFloatArr=*/FLOAT_ARRAY2_ROUNDED,
  36427. + /*expectedIntArr=*/INT_ARRAY2)
  36428. + .build()
  36429. + .run();
  36430. + }
  36431. +
  36432. + @Test
  36433. + public void testRepeatedLoadAndGetFloatArrayWithDifferentDynamicSize() {
  36434. + ArrayTestRunner.Builder.newInstance()
  36435. + .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
  36436. + .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
  36437. + .setExpectedResults(
  36438. + /*bufferType = */ DataType.FLOAT32,
  36439. + /*expectedFloatArr=*/FLOAT_ARRAY2,
  36440. + /*expectedIntArr=*/INT_ARRAY2)
  36441. + .setExpectedResults(
  36442. + /*bufferType = */ DataType.UINT8,
  36443. + /*expectedFloatArr=*/FLOAT_ARRAY2_ROUNDED,
  36444. + /*expectedIntArr=*/INT_ARRAY2)
  36445. + .build()
  36446. + .run();
  36447. + }
  36448. +
  36449. + @Test
  36450. + public void testGetForEmptyArrayWithFixedSizeBuffer() {
  36451. + ArrayTestRunner.Builder.newInstance()
  36452. + .setTensorBufferShape(EMPTY_ARRAY_SHAPE)
  36453. + .setExpectedResults(
  36454. + /*bufferType = */ DataType.FLOAT32,
  36455. + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
  36456. + /*expectedIntArr=*/EMPTY_INT_ARRAY)
  36457. + .setExpectedResults(
  36458. + /*bufferType = */ DataType.UINT8,
  36459. + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
  36460. + /*expectedIntArr=*/EMPTY_INT_ARRAY)
  36461. + .build()
  36462. + .run();
  36463. + }
  36464. +
  36465. + @Test
  36466. + public void testGetForEmptyArrayWithDynamicBuffer() {
  36467. + ArrayTestRunner.Builder.newInstance()
  36468. + .setExpectedResults(
  36469. + /*bufferType = */ DataType.FLOAT32,
  36470. + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
  36471. + /*expectedIntArr=*/EMPTY_INT_ARRAY)
  36472. + .setExpectedResults(
  36473. + /*bufferType = */ DataType.UINT8,
  36474. + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
  36475. + /*expectedIntArr=*/EMPTY_INT_ARRAY)
  36476. + .build()
  36477. + .run();
  36478. + }
  36479. +
  36480. + @Test
  36481. + public void testRepeatedLoadAndGetForEmptyArray() {
  36482. + ArrayTestRunner.Builder.newInstance()
  36483. + .addSrcArray(EMPTY_INT_ARRAY, EMPTY_ARRAY_SHAPE)
  36484. + .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
  36485. + .addSrcArray(EMPTY_FLOAT_ARRAY, EMPTY_ARRAY_SHAPE)
  36486. + .setExpectedResults(
  36487. + /*bufferType = */ DataType.FLOAT32,
  36488. + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
  36489. + /*expectedIntArr=*/EMPTY_INT_ARRAY)
  36490. + .setExpectedResults(
  36491. + /*bufferType = */ DataType.UINT8,
  36492. + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
  36493. + /*expectedIntArr=*/EMPTY_INT_ARRAY)
  36494. + .build()
  36495. + .run();
  36496. + }
  36497. +
  36498. + @Test
  36499. + public void testLoadNullIntArrays() {
  36500. + int[] nullArray = null;
  36501. + int[] shape = new int[] {};
  36502. + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36503. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  36504. + Assert.assertThrows(
  36505. + NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape));
  36506. + }
  36507. + }
  36508. +
  36509. + @Test
  36510. + public void testLoadNullFloatArrays() {
  36511. + float[] nullArray = null;
  36512. + int[] shape = new int[] {};
  36513. + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36514. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  36515. + Assert.assertThrows(
  36516. + NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape));
  36517. + }
  36518. + }
  36519. +
  36520. + @Test
  36521. + public void testLoadFloatArraysWithNullShape() {
  36522. + float[] arr = new float[] {1.0f};
  36523. + int[] nullShape = null;
  36524. + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36525. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  36526. + Assert.assertThrows(
  36527. + NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape));
  36528. + }
  36529. + }
  36530. +
  36531. + @Test
  36532. + public void testLoadIntArraysWithNullShape() {
  36533. + int[] arr = new int[] {1};
  36534. + int[] nullShape = null;
  36535. + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36536. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  36537. + Assert.assertThrows(
  36538. + NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape));
  36539. + }
  36540. + }
  36541. +
  36542. + @Test
  36543. + public void testLoadIntArraysWithoutShapeAndArrayDoesNotMatchShape() {
  36544. + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36545. + TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType);
  36546. + Assert.assertThrows(
  36547. + IllegalArgumentException.class, () -> fixedTensorBuffer.loadArray(INT_ARRAY2));
  36548. + }
  36549. + }
  36550. +
  36551. + @Test
  36552. + public void testLoadFloatArraysWithoutShapeAndArrayDoesNotMatchShape() {
  36553. + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36554. + TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType);
  36555. + Assert.assertThrows(IllegalArgumentException.class,
  36556. + () -> fixedTensorBuffer.loadArray(FLOAT_ARRAY2));
  36557. + }
  36558. + }
  36559. +
  36560. + @Test
  36561. + public void testLoadByteBufferForNullBuffer() {
  36562. + ByteBuffer byteBuffer = null;
  36563. + int[] shape = new int[] {};
  36564. + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36565. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  36566. + Assert.assertThrows(
  36567. + NullPointerException.class, () -> tensorBuffer.loadBuffer(byteBuffer, shape));
  36568. + }
  36569. + }
  36570. +
  36571. + @Test
  36572. + public void testLoadByteBufferForEmptyBuffer() {
  36573. + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36574. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  36575. + tensorBuffer.loadBuffer(EMPTY_BYTE_BUFFER, EMPTY_ARRAY_SHAPE);
  36576. + assertThat(tensorBuffer.getFlatSize()).isEqualTo(0);
  36577. + }
  36578. + }
  36579. +
  36580. + @Test
  36581. + public void testLoadByteBufferWithDifferentFixedSize() {
  36582. + // Create a fixed-size TensorBuffer with size 2, and load a ByteBuffer with size 5.
  36583. + int[] tensorBufferShape = new int[] {2};
  36584. + TensorBuffer tensorBuffer =
  36585. + TensorBuffer.createFixedSize(tensorBufferShape, DataType.FLOAT32);
  36586. + Assert.assertThrows(IllegalArgumentException.class,
  36587. + () -> tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE));
  36588. + }
  36589. +
  36590. + @Test
  36591. + public void testLoadByteBufferWithMisMatchDataType() {
  36592. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  36593. + int[] wrongShape = new int[] {1};
  36594. + // Size of INT_BYTE_BUFFER is 8 bytes. It does not match the specified shape.
  36595. + Assert.assertThrows(IllegalArgumentException.class,
  36596. + () -> tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, wrongShape));
  36597. + }
  36598. +
  36599. + @Test
  36600. + public void testLoadByteBufferForTensorBufferFloat() {
  36601. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
  36602. + tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE);
  36603. + assertThat(tensorBuffer.getFloatArray()).isEqualTo(FLOAT_ARRAY1);
  36604. + assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY1_SHAPE);
  36605. + }
  36606. +
  36607. + @Test
  36608. + public void testLoadByteBufferForTensorBufferUint8() {
  36609. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
  36610. + tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, ARRAY2_SHAPE);
  36611. + assertThat(tensorBuffer.getIntArray()).isEqualTo(INT_ARRAY2);
  36612. + assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY2_SHAPE);
  36613. + }
  36614. +
  36615. + @Test
  36616. + public void testGetFloatValueWithInvalidIndex() {
  36617. + float[] arrayWithSixElements = FLOAT_ARRAY1;
  36618. + int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE;
  36619. + int[] invalidIndexes = {-1, 7};
  36620. + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36621. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  36622. + tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements);
  36623. + for (int invalidIndex : invalidIndexes) {
  36624. + Assert.assertThrows(IndexOutOfBoundsException.class,
  36625. + () -> tensorBuffer.getFloatValue(invalidIndex));
  36626. + }
  36627. + }
  36628. + }
  36629. +
  36630. + @Test
  36631. + public void testGetFloatValueFromScalarWithInvalidIndex() {
  36632. + int[] shape = new int[] {};
  36633. + float[] arr = new float[] {10.0f};
  36634. + int[] invalidIndexes =
  36635. + new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize.
  36636. + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36637. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  36638. + tensorBuffer.loadArray(arr, shape);
  36639. + for (int invalidIndex : invalidIndexes) {
  36640. + Assert.assertThrows(IndexOutOfBoundsException.class,
  36641. + () -> tensorBuffer.getFloatValue(invalidIndex));
  36642. + }
  36643. + }
  36644. + }
  36645. +
  36646. + @Test
  36647. + public void testGetIntValueWithInvalidIndex() {
  36648. + float[] arrayWithSixElements = FLOAT_ARRAY1;
  36649. + int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE;
  36650. + int[] invalidIndexes = {-1, 7};
  36651. + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36652. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  36653. + tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements);
  36654. + for (int invalidIndex : invalidIndexes) {
  36655. + Assert.assertThrows(IndexOutOfBoundsException.class,
  36656. + () -> tensorBuffer.getIntValue(invalidIndex));
  36657. + }
  36658. + }
  36659. + }
  36660. +
  36661. + @Test
  36662. + public void testGetIntValueFromScalarWithInvalidIndex() {
  36663. + int[] shape = new int[] {};
  36664. + float[] arr = new float[] {10.0f};
  36665. + int[] invalidIndexes =
  36666. + new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize.
  36667. + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
  36668. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
  36669. + tensorBuffer.loadArray(arr, shape);
  36670. + for (int invalidIndex : invalidIndexes) {
  36671. + Assert.assertThrows(IndexOutOfBoundsException.class,
  36672. + () -> tensorBuffer.getIntValue(invalidIndex));
  36673. + }
  36674. + }
  36675. + }
  36676. +
  36677. + @Test
  36678. + public void testLoadByteBufferSliceForTensorBufferFloat() {
  36679. + TensorBuffer original = TensorBuffer.createDynamic(DataType.FLOAT32);
  36680. + original.loadArray(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, new int[] {6});
  36681. + ByteBuffer buffer = original.getBuffer();
  36682. + // Slice original buffer to 3 sub-buffer, each of which has 2 element
  36683. + int numBuffers = 3;
  36684. + int numElements = 2;
  36685. + int subArrayLength = numElements * original.getTypeSize();
  36686. + TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType());
  36687. + for (int i = 0; i < numBuffers; i++) {
  36688. + buffer.position(i * subArrayLength);
  36689. + ByteBuffer subBuffer = buffer.slice();
  36690. + // ByteBuffer.slice doesn't keep order.
  36691. + subBuffer.order(buffer.order()).limit(subArrayLength);
  36692. + tensorSlice.loadBuffer(subBuffer, new int[] {numElements});
  36693. + float[] arraySlice = tensorSlice.getFloatArray();
  36694. + assertThat(arraySlice.length).isEqualTo(numElements);
  36695. + assertThat(arraySlice[0]).isEqualTo(i * numElements + 1);
  36696. + assertThat(arraySlice[1]).isEqualTo(i * numElements + 2);
  36697. + }
  36698. + }
  36699. +
  36700. + @Test
  36701. + public void testLoadByteBufferSliceForTensorBufferUInt8() {
  36702. + TensorBuffer original = TensorBuffer.createDynamic(DataType.UINT8);
  36703. + original.loadArray(new int[] {1, 2, 3, 4, 5, 6}, new int[] {6});
  36704. + ByteBuffer buffer = original.getBuffer();
  36705. + // Slice original buffer to 3 sub-buffer, each of which has 2 element
  36706. + int numBuffers = 3;
  36707. + int numElements = 2;
  36708. + int subArrayLength = numElements * original.getTypeSize();
  36709. + TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType());
  36710. + for (int i = 0; i < numBuffers; i++) {
  36711. + buffer.position(i * subArrayLength);
  36712. + ByteBuffer subBuffer = buffer.slice();
  36713. + // ByteBuffer.slice doesn't keep order.
  36714. + subBuffer.order(buffer.order()).limit(subArrayLength);
  36715. + tensorSlice.loadBuffer(subBuffer, new int[] {numElements});
  36716. + int[] arraySlice = tensorSlice.getIntArray();
  36717. + assertThat(arraySlice.length).isEqualTo(numElements);
  36718. + assertThat(arraySlice[0]).isEqualTo(i * numElements + 1);
  36719. + assertThat(arraySlice[1]).isEqualTo(i * numElements + 2);
  36720. + }
  36721. + }
  36722. +
  36723. + @Test
  36724. + public void getShapeFailsAfterByteBufferChanged() {
  36725. + TensorBuffer tensorBuffer =
  36726. + TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32);
  36727. + ByteBuffer byteBuffer = tensorBuffer.getBuffer();
  36728. + byteBuffer.limit(5);
  36729. +
  36730. + IllegalStateException exception =
  36731. + assertThrows(IllegalStateException.class, tensorBuffer::getShape);
  36732. + assertThat(exception).hasMessageThat().contains(
  36733. + "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The"
  36734. + " ByteBuffer may have been changed.");
  36735. - }
  36736. -
  36737. - @Test
  36738. - public void getFlatSizeFailsAfterByteBufferChanged() {
  36739. - TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32);
  36740. - ByteBuffer byteBuffer = tensorBuffer.getBuffer();
  36741. - byteBuffer.limit(5);
  36742. -
  36743. - IllegalStateException exception =
  36744. - assertThrows(IllegalStateException.class, tensorBuffer::getFlatSize);
  36745. - assertThat(exception)
  36746. - .hasMessageThat()
  36747. - .contains(
  36748. - "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The"
  36749. + }
  36750. +
  36751. + @Test
  36752. + public void getFlatSizeFailsAfterByteBufferChanged() {
  36753. + TensorBuffer tensorBuffer =
  36754. + TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32);
  36755. + ByteBuffer byteBuffer = tensorBuffer.getBuffer();
  36756. + byteBuffer.limit(5);
  36757. +
  36758. + IllegalStateException exception =
  36759. + assertThrows(IllegalStateException.class, tensorBuffer::getFlatSize);
  36760. + assertThat(exception).hasMessageThat().contains(
  36761. + "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The"
  36762. + " ByteBuffer may have been changed.");
  36763. - }
  36764. -
  36765. - @Test
  36766. - public void loadReadOnlyBuffersCopiesOnWrite() {
  36767. - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
  36768. - ByteBuffer originalByteBuffer = ByteBuffer.allocateDirect(1);
  36769. - originalByteBuffer.put(new byte[]{99});
  36770. - originalByteBuffer.rewind();
  36771. - ByteBuffer readOnlyByteBuffer = originalByteBuffer.asReadOnlyBuffer();
  36772. -
  36773. - tensorBuffer.loadBuffer(readOnlyByteBuffer, new int[]{1});
  36774. - assertThat(tensorBuffer.getBuffer()).isSameInstanceAs(readOnlyByteBuffer);
  36775. -
  36776. - tensorBuffer.loadArray(new int[]{42});
  36777. - assertThat(tensorBuffer.getBuffer()).isNotSameInstanceAs(readOnlyByteBuffer);
  36778. - assertThat(tensorBuffer.getBuffer().get(0)).isEqualTo(42); // updated
  36779. - assertThat(originalByteBuffer.get(0)).isEqualTo(99); // original one not changed
  36780. - }
  36781. + }
  36782. +
  36783. + @Test
  36784. + public void loadReadOnlyBuffersCopiesOnWrite() {
  36785. + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
  36786. + ByteBuffer originalByteBuffer = ByteBuffer.allocateDirect(1);
  36787. + originalByteBuffer.put(new byte[] {99});
  36788. + originalByteBuffer.rewind();
  36789. + ByteBuffer readOnlyByteBuffer = originalByteBuffer.asReadOnlyBuffer();
  36790. +
  36791. + tensorBuffer.loadBuffer(readOnlyByteBuffer, new int[] {1});
  36792. + assertThat(tensorBuffer.getBuffer()).isSameInstanceAs(readOnlyByteBuffer);
  36793. +
  36794. + tensorBuffer.loadArray(new int[] {42});
  36795. + assertThat(tensorBuffer.getBuffer()).isNotSameInstanceAs(readOnlyByteBuffer);
  36796. + assertThat(tensorBuffer.getBuffer().get(0)).isEqualTo(42); // updated
  36797. + assertThat(originalByteBuffer.get(0)).isEqualTo(99); // original one not changed
  36798. + }
  36799. }
  36800. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java
  36801. index e843133275d61..1921f4e467d01 100644
  36802. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java
  36803. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java
  36804. @@ -26,51 +26,51 @@ import org.tensorflow.lite.DataType;
  36805. /** Tests of {@link org.tensorflow.lite.support.tensorbuffer.TensorBufferUint8}. */
  36806. @RunWith(RobolectricTestRunner.class)
  36807. public final class TensorBufferUint8Test {
  36808. - @Test
  36809. - public void testCreateDynamic() {
  36810. - TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8();
  36811. - assertThat(tensorBufferUint8).isNotNull();
  36812. - }
  36813. + @Test
  36814. + public void testCreateDynamic() {
  36815. + TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8();
  36816. + assertThat(tensorBufferUint8).isNotNull();
  36817. + }
  36818. - @Test
  36819. - public void testCreateFixedSize() {
  36820. - int[] shape = new int[] {1, 2, 3};
  36821. - TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
  36822. - assertThat(tensorBufferUint8).isNotNull();
  36823. - assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6);
  36824. - }
  36825. + @Test
  36826. + public void testCreateFixedSize() {
  36827. + int[] shape = new int[] {1, 2, 3};
  36828. + TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
  36829. + assertThat(tensorBufferUint8).isNotNull();
  36830. + assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6);
  36831. + }
  36832. - @Test
  36833. - public void testCreateFixedSizeWithScalarShape() {
  36834. - int[] shape = new int[] {};
  36835. - TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
  36836. - assertThat(tensorBufferUint8).isNotNull();
  36837. - assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(1);
  36838. - }
  36839. + @Test
  36840. + public void testCreateFixedSizeWithScalarShape() {
  36841. + int[] shape = new int[] {};
  36842. + TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
  36843. + assertThat(tensorBufferUint8).isNotNull();
  36844. + assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(1);
  36845. + }
  36846. - @Test
  36847. - public void testCreateWithNullShape() {
  36848. - int[] shape = null;
  36849. - Assert.assertThrows(NullPointerException.class, () -> new TensorBufferUint8(shape));
  36850. - }
  36851. + @Test
  36852. + public void testCreateWithNullShape() {
  36853. + int[] shape = null;
  36854. + Assert.assertThrows(NullPointerException.class, () -> new TensorBufferUint8(shape));
  36855. + }
  36856. - @Test
  36857. - public void testCreateWithInvalidShape() {
  36858. - int[] shape = new int[] {1, -1, 2};
  36859. - Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferUint8(shape));
  36860. - }
  36861. + @Test
  36862. + public void testCreateWithInvalidShape() {
  36863. + int[] shape = new int[] {1, -1, 2};
  36864. + Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferUint8(shape));
  36865. + }
  36866. - @Test
  36867. - public void testCreateUsingShapeWithZero() {
  36868. - int[] shape = new int[] {1, 0, 2};
  36869. - TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
  36870. - assertThat(tensorBufferUint8).isNotNull();
  36871. - assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(0);
  36872. - }
  36873. + @Test
  36874. + public void testCreateUsingShapeWithZero() {
  36875. + int[] shape = new int[] {1, 0, 2};
  36876. + TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
  36877. + assertThat(tensorBufferUint8).isNotNull();
  36878. + assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(0);
  36879. + }
  36880. - @Test
  36881. - public void testGetDataType() {
  36882. - TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8();
  36883. - assertThat(tensorBufferUint8.getDataType()).isEqualTo(DataType.UINT8);
  36884. - }
  36885. + @Test
  36886. + public void testGetDataType() {
  36887. + TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8();
  36888. + assertThat(tensorBufferUint8.getDataType()).isEqualTo(DataType.UINT8);
  36889. + }
  36890. }
  36891. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc
  36892. index d62da546a484b..c3c21fa43ab49 100644
  36893. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc
  36894. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc
  36895. @@ -134,7 +134,8 @@ jobject ConvertToClassificationResults(JNIEnv* env,
  36896. }
  36897. // Creates an AudioClassifierOptions proto based on the Java class.
  36898. -AudioClassifierOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options,
  36899. +AudioClassifierOptions ConvertToProtoOptions(JNIEnv* env,
  36900. + jobject java_options,
  36901. jlong base_options_handle) {
  36902. AudioClassifierOptions proto_options;
  36903. @@ -214,7 +215,9 @@ jlong CreateAudioClassifierFromOptions(JNIEnv* env,
  36904. extern "C" JNIEXPORT void JNICALL
  36905. Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_deinitJni(
  36906. - JNIEnv* env, jobject thiz, jlong native_handle) {
  36907. + JNIEnv* env,
  36908. + jobject thiz,
  36909. + jlong native_handle) {
  36910. delete reinterpret_cast<AudioClassifier*>(native_handle);
  36911. }
  36912. @@ -223,9 +226,13 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_deinitJni(
  36913. // values will be ignored.
  36914. extern "C" JNIEXPORT jlong JNICALL
  36915. Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithModelFdAndOptions(
  36916. - JNIEnv* env, jclass thiz, jint file_descriptor,
  36917. - jlong file_descriptor_length, jlong file_descriptor_offset,
  36918. - jobject java_options, jlong base_options_handle) {
  36919. + JNIEnv* env,
  36920. + jclass thiz,
  36921. + jint file_descriptor,
  36922. + jlong file_descriptor_length,
  36923. + jlong file_descriptor_offset,
  36924. + jobject java_options,
  36925. + jlong base_options_handle) {
  36926. AudioClassifierOptions proto_options =
  36927. ConvertToProtoOptions(env, java_options, base_options_handle);
  36928. auto file_descriptor_meta = proto_options.mutable_base_options()
  36929. @@ -243,7 +250,10 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithModelF
  36930. extern "C" JNIEXPORT jlong JNICALL
  36931. Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithByteBuffer(
  36932. - JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options,
  36933. + JNIEnv* env,
  36934. + jclass thiz,
  36935. + jobject model_buffer,
  36936. + jobject java_options,
  36937. jlong base_options_handle) {
  36938. AudioClassifierOptions proto_options =
  36939. ConvertToProtoOptions(env, java_options, base_options_handle);
  36940. @@ -262,7 +272,9 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithByteBu
  36941. // caching it in JAVA layer.
  36942. extern "C" JNIEXPORT jlong JNICALL
  36943. Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredSampleRateNative(
  36944. - JNIEnv* env, jclass thiz, jlong native_handle) {
  36945. + JNIEnv* env,
  36946. + jclass thiz,
  36947. + jlong native_handle) {
  36948. auto* classifier = reinterpret_cast<AudioClassifier*>(native_handle);
  36949. StatusOr<AudioBuffer::AudioFormat> format_or =
  36950. classifier->GetRequiredAudioFormat();
  36951. @@ -279,7 +291,9 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredSample
  36952. extern "C" JNIEXPORT jlong JNICALL
  36953. Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredChannelsNative(
  36954. - JNIEnv* env, jclass thiz, jlong native_handle) {
  36955. + JNIEnv* env,
  36956. + jclass thiz,
  36957. + jlong native_handle) {
  36958. auto* classifier = reinterpret_cast<AudioClassifier*>(native_handle);
  36959. StatusOr<AudioBuffer::AudioFormat> format_or =
  36960. classifier->GetRequiredAudioFormat();
  36961. @@ -296,15 +310,21 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredChanne
  36962. extern "C" JNIEXPORT jlong JNICALL
  36963. Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredInputBufferSizeNative(
  36964. - JNIEnv* env, jclass thiz, jlong native_handle) {
  36965. + JNIEnv* env,
  36966. + jclass thiz,
  36967. + jlong native_handle) {
  36968. auto* classifier = reinterpret_cast<AudioClassifier*>(native_handle);
  36969. return classifier->GetRequiredInputBufferSize();
  36970. }
  36971. extern "C" JNIEXPORT jobject JNICALL
  36972. Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_classifyNative(
  36973. - JNIEnv* env, jclass thiz, jlong native_handle, jbyteArray java_array,
  36974. - jint channels, jint sample_rate) {
  36975. + JNIEnv* env,
  36976. + jclass thiz,
  36977. + jlong native_handle,
  36978. + jbyteArray java_array,
  36979. + jint channels,
  36980. + jint sample_rate) {
  36981. // Get the primitive native array. Depending on the JAVA runtime, the returned
  36982. // array might be a copy of the JAVA array (or not).
  36983. jbyte* native_array = env->GetByteArrayElements(java_array, nullptr);
  36984. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc
  36985. index 2fd1d7ca9a593..75f93d6f2e458 100644
  36986. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc
  36987. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc
  36988. @@ -30,7 +30,10 @@ using ::tflite::task::core::BaseOptions;
  36989. extern "C" JNIEXPORT jlong JNICALL
  36990. Java_org_tensorflow_lite_task_core_TaskJniUtils_createProtoBaseOptions(
  36991. - JNIEnv* env, jclass thiz, jint delegate, jint num_threads) {
  36992. + JNIEnv* env,
  36993. + jclass thiz,
  36994. + jint delegate,
  36995. + jint num_threads) {
  36996. StatusOr<Delegate> delegate_proto_or = ConvertToProtoDelegate(delegate);
  36997. if (!delegate_proto_or.ok()) {
  36998. ThrowException(env, kIllegalStateException,
  36999. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc
  37000. index 6657ef4ca2d95..2daacdf893903 100644
  37001. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc
  37002. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc
  37003. @@ -32,7 +32,9 @@ using ::tflite::task::text::BertNLClassifierOptions;
  37004. using ::tflite::task::text::nlclassifier::RunClassifier;
  37005. BertNLClassifierOptions ConvertJavaBertNLClassifierOptions(
  37006. - JNIEnv* env, jobject java_options, jlong base_options_handle) {
  37007. + JNIEnv* env,
  37008. + jobject java_options,
  37009. + jlong base_options_handle) {
  37010. BertNLClassifierOptions proto_options;
  37011. if (base_options_handle != kInvalidPointer) {
  37012. @@ -47,13 +49,18 @@ BertNLClassifierOptions ConvertJavaBertNLClassifierOptions(
  37013. extern "C" JNIEXPORT void JNICALL
  37014. Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_deinitJni(
  37015. - JNIEnv* env, jobject thiz, jlong native_handle) {
  37016. + JNIEnv* env,
  37017. + jobject thiz,
  37018. + jlong native_handle) {
  37019. delete reinterpret_cast<BertNLClassifier*>(native_handle);
  37020. }
  37021. extern "C" JNIEXPORT jlong JNICALL
  37022. Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByteBuffer(
  37023. - JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options,
  37024. + JNIEnv* env,
  37025. + jclass thiz,
  37026. + jobject model_buffer,
  37027. + jobject java_options,
  37028. jlong base_options_handle) {
  37029. BertNLClassifierOptions proto_options = ConvertJavaBertNLClassifierOptions(
  37030. env, java_options, base_options_handle);
  37031. @@ -76,7 +83,10 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByte
  37032. extern "C" JNIEXPORT jlong JNICALL
  37033. Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithFileDescriptor(
  37034. - JNIEnv* env, jclass thiz, jint fd, jobject java_options,
  37035. + JNIEnv* env,
  37036. + jclass thiz,
  37037. + jint fd,
  37038. + jobject java_options,
  37039. jlong base_options_handle) {
  37040. BertNLClassifierOptions proto_options = ConvertJavaBertNLClassifierOptions(
  37041. env, java_options, base_options_handle);
  37042. @@ -100,6 +110,9 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithFile
  37043. extern "C" JNIEXPORT jobject JNICALL
  37044. Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_classifyNative(
  37045. - JNIEnv* env, jclass clazz, jlong native_handle, jstring text) {
  37046. + JNIEnv* env,
  37047. + jclass clazz,
  37048. + jlong native_handle,
  37049. + jstring text) {
  37050. return RunClassifier(env, native_handle, text);
  37051. }
  37052. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc
  37053. index f6d34a5f74e2b..4c71a80ea1528 100644
  37054. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc
  37055. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc
  37056. @@ -94,14 +94,19 @@ NLClassifierOptions ConvertToProtoOptions(JNIEnv* env,
  37057. extern "C" JNIEXPORT void JNICALL
  37058. Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_deinitJni(
  37059. - JNIEnv* env, jobject thiz, jlong native_handle) {
  37060. + JNIEnv* env,
  37061. + jobject thiz,
  37062. + jlong native_handle) {
  37063. delete reinterpret_cast<NLClassifier*>(native_handle);
  37064. }
  37065. extern "C" JNIEXPORT jlong JNICALL
  37066. Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithByteBuffer(
  37067. - JNIEnv* env, jclass thiz, jobject nl_classifier_options,
  37068. - jobject model_buffer, jlong base_options_handle) {
  37069. + JNIEnv* env,
  37070. + jclass thiz,
  37071. + jobject nl_classifier_options,
  37072. + jobject model_buffer,
  37073. + jlong base_options_handle) {
  37074. auto model = GetMappedFileBuffer(env, model_buffer);
  37075. tflite::support::StatusOr<std::unique_ptr<NLClassifier>> classifier_or;
  37076. @@ -125,7 +130,10 @@ Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithByteBuff
  37077. extern "C" JNIEXPORT jlong JNICALL
  37078. Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithFileDescriptor(
  37079. - JNIEnv* env, jclass thiz, jobject nl_classifier_options, jint fd,
  37080. + JNIEnv* env,
  37081. + jclass thiz,
  37082. + jobject nl_classifier_options,
  37083. + jint fd,
  37084. jlong base_options_handle) {
  37085. tflite::support::StatusOr<std::unique_ptr<NLClassifier>> classifier_or;
  37086. @@ -151,6 +159,9 @@ Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithFileDesc
  37087. extern "C" JNIEXPORT jobject JNICALL
  37088. Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_classifyNative(
  37089. - JNIEnv* env, jclass thiz, jlong native_handle, jstring text) {
  37090. + JNIEnv* env,
  37091. + jclass thiz,
  37092. + jlong native_handle,
  37093. + jstring text) {
  37094. return RunClassifier(env, native_handle, text);
  37095. }
  37096. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc
  37097. index 1ff0d9fc46161..b77746a2eee68 100644
  37098. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc
  37099. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc
  37100. @@ -52,14 +52,19 @@ BertQuestionAnswererOptions ConvertToProtoOptions(jlong base_options_handle) {
  37101. extern "C" JNIEXPORT void JNICALL
  37102. Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_deinitJni(
  37103. - JNIEnv* env, jobject thiz, jlong native_handle) {
  37104. + JNIEnv* env,
  37105. + jobject thiz,
  37106. + jlong native_handle) {
  37107. delete reinterpret_cast<QuestionAnswerer*>(native_handle);
  37108. }
  37109. extern "C" JNIEXPORT jlong JNICALL
  37110. Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescriptor(
  37111. - JNIEnv* env, jclass thiz, jint file_descriptor,
  37112. - jlong file_descriptor_length, jlong file_descriptor_offset,
  37113. + JNIEnv* env,
  37114. + jclass thiz,
  37115. + jint file_descriptor,
  37116. + jlong file_descriptor_length,
  37117. + jlong file_descriptor_offset,
  37118. jlong base_options_handle) {
  37119. BertQuestionAnswererOptions proto_options =
  37120. ConvertToProtoOptions(base_options_handle);
  37121. @@ -89,7 +94,9 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescri
  37122. extern "C" JNIEXPORT jlong JNICALL
  37123. Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers(
  37124. - JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
  37125. + JNIEnv* env,
  37126. + jclass thiz,
  37127. + jobjectArray model_buffers) {
  37128. absl::string_view model =
  37129. GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
  37130. absl::string_view vocab =
  37131. @@ -111,7 +118,9 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBu
  37132. extern "C" JNIEXPORT jlong JNICALL
  37133. Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithAlbertByteBuffers(
  37134. - JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
  37135. + JNIEnv* env,
  37136. + jclass thiz,
  37137. + jobjectArray model_buffers) {
  37138. absl::string_view model =
  37139. GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
  37140. absl::string_view sp_model =
  37141. @@ -133,7 +142,10 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithAlbertByte
  37142. extern "C" JNIEXPORT jobject JNICALL
  37143. Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(
  37144. - JNIEnv* env, jclass thiz, jlong native_handle, jstring context,
  37145. + JNIEnv* env,
  37146. + jclass thiz,
  37147. + jlong native_handle,
  37148. + jstring context,
  37149. jstring question) {
  37150. auto* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle);
  37151. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/text_searcher_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/text_searcher_jni.cc
  37152. index 8573b0f444626..c207755d3393f 100644
  37153. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/text_searcher_jni.cc
  37154. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/text_searcher_jni.cc
  37155. @@ -48,7 +48,8 @@ using ::tflite::task::text::TextSearcherOptions;
  37156. // Creates an TextSearcherOptions proto based on the Java class.
  37157. TextSearcherOptions ConvertToProtoOptions(jlong base_options_handle,
  37158. - bool l2_normalize, bool quantize,
  37159. + bool l2_normalize,
  37160. + bool quantize,
  37161. int index_descriptor,
  37162. int max_results) {
  37163. TextSearcherOptions proto_options;
  37164. @@ -120,7 +121,9 @@ jobject ConvertToSearchResults(JNIEnv* env, const SearchResult& results) {
  37165. extern "C" JNIEXPORT void JNICALL
  37166. Java_org_tensorflow_lite_task_text_searcher_TextSearcher_deinitJni(
  37167. - JNIEnv* env, jobject thiz, jlong native_handle) {
  37168. + JNIEnv* env,
  37169. + jobject thiz,
  37170. + jlong native_handle) {
  37171. delete reinterpret_cast<TextSearcher*>(native_handle);
  37172. }
  37173. @@ -129,10 +132,16 @@ Java_org_tensorflow_lite_task_text_searcher_TextSearcher_deinitJni(
  37174. // values will be ignored.
  37175. extern "C" JNIEXPORT jlong JNICALL
  37176. Java_org_tensorflow_lite_task_text_searcher_TextSearcher_initJniWithModelFdAndOptions(
  37177. - JNIEnv* env, jclass thiz, jint model_descriptor,
  37178. - jlong model_descriptor_length, jlong model_descriptor_offset,
  37179. - jlong base_options_handle, bool l2_normalize, bool quantize,
  37180. - jint index_descriptor, int max_results) {
  37181. + JNIEnv* env,
  37182. + jclass thiz,
  37183. + jint model_descriptor,
  37184. + jlong model_descriptor_length,
  37185. + jlong model_descriptor_offset,
  37186. + jlong base_options_handle,
  37187. + bool l2_normalize,
  37188. + bool quantize,
  37189. + jint index_descriptor,
  37190. + int max_results) {
  37191. TextSearcherOptions proto_options =
  37192. ConvertToProtoOptions(base_options_handle, l2_normalize, quantize,
  37193. index_descriptor, max_results);
  37194. @@ -152,8 +161,14 @@ Java_org_tensorflow_lite_task_text_searcher_TextSearcher_initJniWithModelFdAndOp
  37195. extern "C" JNIEXPORT jlong JNICALL
  37196. Java_org_tensorflow_lite_task_text_searcher_TextSearcher_initJniWithByteBuffer(
  37197. - JNIEnv* env, jclass thiz, jobject model_buffer, jlong base_options_handle,
  37198. - bool l2_normalize, bool quantize, jlong index_descriptor, int max_results) {
  37199. + JNIEnv* env,
  37200. + jclass thiz,
  37201. + jobject model_buffer,
  37202. + jlong base_options_handle,
  37203. + bool l2_normalize,
  37204. + bool quantize,
  37205. + jlong index_descriptor,
  37206. + int max_results) {
  37207. TextSearcherOptions proto_options =
  37208. ConvertToProtoOptions(base_options_handle, l2_normalize, quantize,
  37209. index_descriptor, max_results);
  37210. @@ -166,7 +181,10 @@ Java_org_tensorflow_lite_task_text_searcher_TextSearcher_initJniWithByteBuffer(
  37211. extern "C" JNIEXPORT jobject JNICALL
  37212. Java_org_tensorflow_lite_task_text_searcher_TextSearcher_searchNative(
  37213. - JNIEnv* env, jclass thiz, jlong native_handle, jstring text) {
  37214. + JNIEnv* env,
  37215. + jclass thiz,
  37216. + jlong native_handle,
  37217. + jstring text) {
  37218. auto* searcher = reinterpret_cast<TextSearcher*>(native_handle);
  37219. auto results_or = searcher->Search(JStringToString(env, text));
  37220. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc
  37221. index 18e2ee1a7d4ab..2a713cf8b63cf 100644
  37222. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc
  37223. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc
  37224. @@ -54,7 +54,8 @@ using ::tflite::task::vision::ImageClassifier;
  37225. using ::tflite::task::vision::ImageClassifierOptions;
  37226. // Creates an ImageClassifierOptions proto based on the Java class.
  37227. -ImageClassifierOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options,
  37228. +ImageClassifierOptions ConvertToProtoOptions(JNIEnv* env,
  37229. + jobject java_options,
  37230. jlong base_options_handle) {
  37231. ImageClassifierOptions proto_options;
  37232. @@ -175,7 +176,9 @@ jlong CreateImageClassifierFromOptions(JNIEnv* env,
  37233. extern "C" JNIEXPORT void JNICALL
  37234. Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_deinitJni(
  37235. - JNIEnv* env, jobject thiz, jlong native_handle) {
  37236. + JNIEnv* env,
  37237. + jobject thiz,
  37238. + jlong native_handle) {
  37239. delete reinterpret_cast<ImageClassifier*>(native_handle);
  37240. }
  37241. @@ -184,9 +187,13 @@ Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_deinitJni(
  37242. // values will be ignored.
  37243. extern "C" JNIEXPORT jlong JNICALL
  37244. Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithModelFdAndOptions(
  37245. - JNIEnv* env, jclass thiz, jint file_descriptor,
  37246. - jlong file_descriptor_length, jlong file_descriptor_offset,
  37247. - jobject java_options, jlong base_options_handle) {
  37248. + JNIEnv* env,
  37249. + jclass thiz,
  37250. + jint file_descriptor,
  37251. + jlong file_descriptor_length,
  37252. + jlong file_descriptor_offset,
  37253. + jobject java_options,
  37254. + jlong base_options_handle) {
  37255. ImageClassifierOptions proto_options =
  37256. ConvertToProtoOptions(env, java_options, base_options_handle);
  37257. auto file_descriptor_meta = proto_options.mutable_base_options()
  37258. @@ -204,7 +211,10 @@ Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithModel
  37259. extern "C" JNIEXPORT jlong JNICALL
  37260. Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithByteBuffer(
  37261. - JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options,
  37262. + JNIEnv* env,
  37263. + jclass thiz,
  37264. + jobject model_buffer,
  37265. + jobject java_options,
  37266. jlong base_options_handle) {
  37267. ImageClassifierOptions proto_options =
  37268. ConvertToProtoOptions(env, java_options, base_options_handle);
  37269. @@ -220,7 +230,10 @@ Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithByteB
  37270. extern "C" JNIEXPORT jobject JNICALL
  37271. Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_classifyNative(
  37272. - JNIEnv* env, jclass thiz, jlong native_handle, jlong frame_buffer_handle,
  37273. + JNIEnv* env,
  37274. + jclass thiz,
  37275. + jlong native_handle,
  37276. + jlong frame_buffer_handle,
  37277. jintArray jroi) {
  37278. auto* classifier = reinterpret_cast<ImageClassifier*>(native_handle);
  37279. // frame_buffer will be deleted after inference is done in
  37280. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc
  37281. index 84bff227f2543..2cda1b500aeb5 100644
  37282. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc
  37283. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc
  37284. @@ -31,8 +31,13 @@ using ::tflite::task::vision::FrameBuffer;
  37285. extern "C" JNIEXPORT jlong JNICALL
  37286. Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFromByteBuffer(
  37287. - JNIEnv* env, jclass thiz, jobject jimage_byte_buffer, jint width,
  37288. - jint height, jint jorientation, jint jcolor_space_type) {
  37289. + JNIEnv* env,
  37290. + jclass thiz,
  37291. + jobject jimage_byte_buffer,
  37292. + jint width,
  37293. + jint height,
  37294. + jint jorientation,
  37295. + jint jcolor_space_type) {
  37296. auto frame_buffer_or = CreateFrameBufferFromByteBuffer(
  37297. env, jimage_byte_buffer, width, height, jorientation, jcolor_space_type);
  37298. if (frame_buffer_or.ok()) {
  37299. @@ -49,8 +54,14 @@ Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFro
  37300. extern "C" JNIEXPORT jlong JNICALL
  37301. Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFromBytes(
  37302. - JNIEnv* env, jclass thiz, jbyteArray jimage_bytes, jint width, jint height,
  37303. - jint jorientation, jint jcolor_space_type, jlongArray jbyte_array_handle) {
  37304. + JNIEnv* env,
  37305. + jclass thiz,
  37306. + jbyteArray jimage_bytes,
  37307. + jint width,
  37308. + jint height,
  37309. + jint jorientation,
  37310. + jint jcolor_space_type,
  37311. + jlongArray jbyte_array_handle) {
  37312. auto frame_buffer_or =
  37313. CreateFrameBufferFromBytes(env, jimage_bytes, width, height, jorientation,
  37314. jcolor_space_type, jbyte_array_handle);
  37315. @@ -68,9 +79,17 @@ Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFro
  37316. extern "C" JNIEXPORT jlong JNICALL
  37317. Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFromPlanes(
  37318. - JNIEnv* env, jclass thiz, jobject jy_plane, jobject ju_plane,
  37319. - jobject jv_plane, jint width, jint height, jint row_stride_y,
  37320. - jint row_stride_uv, jint pixel_stride_uv, jint orientation) {
  37321. + JNIEnv* env,
  37322. + jclass thiz,
  37323. + jobject jy_plane,
  37324. + jobject ju_plane,
  37325. + jobject jv_plane,
  37326. + jint width,
  37327. + jint height,
  37328. + jint row_stride_y,
  37329. + jint row_stride_uv,
  37330. + jint pixel_stride_uv,
  37331. + jint orientation) {
  37332. auto frame_buffer_or = CreateFrameBufferFromYuvPlanes(
  37333. env, jy_plane, ju_plane, jv_plane, width, height, row_stride_y,
  37334. row_stride_uv, pixel_stride_uv, orientation);
  37335. @@ -88,8 +107,11 @@ Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFro
  37336. extern "C" JNIEXPORT void JNICALL
  37337. Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_deleteFrameBuffer(
  37338. - JNIEnv* env, jobject thiz, jlong frame_buffer_handle,
  37339. - jlong byte_array_handle, jbyteArray jbyte_array) {
  37340. + JNIEnv* env,
  37341. + jobject thiz,
  37342. + jlong frame_buffer_handle,
  37343. + jlong byte_array_handle,
  37344. + jbyteArray jbyte_array) {
  37345. delete reinterpret_cast<FrameBuffer*>(frame_buffer_handle);
  37346. jbyte* bytes_ptr = reinterpret_cast<jbyte*>(byte_array_handle);
  37347. if (bytes_ptr != NULL) {
  37348. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc
  37349. index ddb0b72a25b65..f720795263791 100644
  37350. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc
  37351. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc
  37352. @@ -54,7 +54,8 @@ using ::tflite::task::vision::ObjectDetector;
  37353. using ::tflite::task::vision::ObjectDetectorOptions;
  37354. // Creates an ObjectDetectorOptions proto based on the Java class.
  37355. -ObjectDetectorOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options,
  37356. +ObjectDetectorOptions ConvertToProtoOptions(JNIEnv* env,
  37357. + jobject java_options,
  37358. jlong base_options_handle) {
  37359. ObjectDetectorOptions proto_options;
  37360. @@ -183,7 +184,9 @@ jlong CreateObjectDetectorFromOptions(JNIEnv* env,
  37361. extern "C" JNIEXPORT void JNICALL
  37362. Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_deinitJni(
  37363. - JNIEnv* env, jobject thiz, jlong native_handle) {
  37364. + JNIEnv* env,
  37365. + jobject thiz,
  37366. + jlong native_handle) {
  37367. delete reinterpret_cast<ObjectDetector*>(native_handle);
  37368. }
  37369. @@ -192,9 +195,13 @@ Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_deinitJni(
  37370. // values will be ignored.
  37371. extern "C" JNIEXPORT jlong JNICALL
  37372. Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithModelFdAndOptions(
  37373. - JNIEnv* env, jclass thiz, jint file_descriptor,
  37374. - jlong file_descriptor_length, jlong file_descriptor_offset,
  37375. - jobject java_options, jlong base_options_handle) {
  37376. + JNIEnv* env,
  37377. + jclass thiz,
  37378. + jint file_descriptor,
  37379. + jlong file_descriptor_length,
  37380. + jlong file_descriptor_offset,
  37381. + jobject java_options,
  37382. + jlong base_options_handle) {
  37383. ObjectDetectorOptions proto_options =
  37384. ConvertToProtoOptions(env, java_options, base_options_handle);
  37385. auto file_descriptor_meta = proto_options.mutable_base_options()
  37386. @@ -212,7 +219,10 @@ Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithModelFdA
  37387. extern "C" JNIEXPORT jlong JNICALL
  37388. Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithByteBuffer(
  37389. - JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options,
  37390. + JNIEnv* env,
  37391. + jclass thiz,
  37392. + jobject model_buffer,
  37393. + jobject java_options,
  37394. jlong base_options_handle) {
  37395. ObjectDetectorOptions proto_options =
  37396. ConvertToProtoOptions(env, java_options, base_options_handle);
  37397. @@ -224,7 +234,10 @@ Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithByteBuff
  37398. extern "C" JNIEXPORT jobject JNICALL
  37399. Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_detectNative(
  37400. - JNIEnv* env, jclass thiz, jlong native_handle, jlong frame_buffer_handle) {
  37401. + JNIEnv* env,
  37402. + jclass thiz,
  37403. + jlong native_handle,
  37404. + jlong frame_buffer_handle) {
  37405. auto* detector = reinterpret_cast<ObjectDetector*>(native_handle);
  37406. // frame_buffer will be deleted after inference is done in
  37407. // base_vision_api_jni.cc.
  37408. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc
  37409. index 1b08e56ed509b..e0c94e2ec72c6 100644
  37410. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc
  37411. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc
  37412. @@ -135,8 +135,12 @@ StatusOr<FrameBuffer::Format> GetYUVImageFormat(const uint8* u_buffer,
  37413. }
  37414. StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromByteBuffer(
  37415. - JNIEnv* env, jobject jimage_byte_buffer, jint width, jint height,
  37416. - jint jorientation, jint jcolor_space_type) {
  37417. + JNIEnv* env,
  37418. + jobject jimage_byte_buffer,
  37419. + jint width,
  37420. + jint height,
  37421. + jint jorientation,
  37422. + jint jcolor_space_type) {
  37423. absl::string_view image = GetMappedFileBuffer(env, jimage_byte_buffer);
  37424. return CreateFromRawBuffer(
  37425. reinterpret_cast<const uint8*>(image.data()),
  37426. @@ -146,8 +150,13 @@ StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromByteBuffer(
  37427. }
  37428. StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromBytes(
  37429. - JNIEnv* env, jbyteArray jimage_bytes, jint width, jint height,
  37430. - jint jorientation, jint jcolor_space_type, jlongArray jbyte_array_handle) {
  37431. + JNIEnv* env,
  37432. + jbyteArray jimage_bytes,
  37433. + jint width,
  37434. + jint height,
  37435. + jint jorientation,
  37436. + jint jcolor_space_type,
  37437. + jlongArray jbyte_array_handle) {
  37438. jbyte* jimage_ptr = env->GetByteArrayElements(jimage_bytes, NULL);
  37439. // Free jimage_ptr together with frame_buffer after inference is finished.
  37440. jlong jimage_ptr_handle = reinterpret_cast<jlong>(jimage_ptr);
  37441. @@ -168,9 +177,16 @@ StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromBytes(
  37442. }
  37443. StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromYuvPlanes(
  37444. - JNIEnv* env, jobject jy_plane, jobject ju_plane, jobject jv_plane,
  37445. - jint width, jint height, jint row_stride_y, jint row_stride_uv,
  37446. - jint pixel_stride_uv, jint jorientation) {
  37447. + JNIEnv* env,
  37448. + jobject jy_plane,
  37449. + jobject ju_plane,
  37450. + jobject jv_plane,
  37451. + jint width,
  37452. + jint height,
  37453. + jint row_stride_y,
  37454. + jint row_stride_uv,
  37455. + jint pixel_stride_uv,
  37456. + jint jorientation) {
  37457. const uint8* y_plane =
  37458. reinterpret_cast<const uint8*>(GetMappedFileBuffer(env, jy_plane).data());
  37459. const uint8* u_plane =
  37460. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h
  37461. index dbe32f8a3f2a5..4d7ec17a1c042 100644
  37462. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h
  37463. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h
  37464. @@ -34,23 +34,35 @@ FrameBuffer::Orientation ConvertToFrameBufferOrientation(JNIEnv* env,
  37465. // Creates FrameBuffer from a direct ByteBuffer.
  37466. ::tflite::support::StatusOr<std::unique_ptr<FrameBuffer>>
  37467. -CreateFrameBufferFromByteBuffer(JNIEnv* env, jobject jimage_byte_buffer,
  37468. - jint width, jint height, jint jorientation,
  37469. +CreateFrameBufferFromByteBuffer(JNIEnv* env,
  37470. + jobject jimage_byte_buffer,
  37471. + jint width,
  37472. + jint height,
  37473. + jint jorientation,
  37474. jint jcolor_space_type);
  37475. // Creates FrameBuffer from a byte array.
  37476. ::tflite::support::StatusOr<std::unique_ptr<FrameBuffer>>
  37477. -CreateFrameBufferFromBytes(JNIEnv* env, jbyteArray jimage_bytes, jint width,
  37478. - jint height, jint jorientation,
  37479. +CreateFrameBufferFromBytes(JNIEnv* env,
  37480. + jbyteArray jimage_bytes,
  37481. + jint width,
  37482. + jint height,
  37483. + jint jorientation,
  37484. jint jcolor_space_type,
  37485. jlongArray jbyte_array_handle);
  37486. // Creates FrameBuffer from YUV planes.
  37487. ::tflite::support::StatusOr<std::unique_ptr<FrameBuffer>>
  37488. -CreateFrameBufferFromYuvPlanes(JNIEnv* env, jobject jy_plane, jobject ju_plane,
  37489. - jobject jv_plane, jint width, jint height,
  37490. - jint row_stride_y, jint row_stride_uv,
  37491. - jint pixel_stride_uv, jint jorientation);
  37492. +CreateFrameBufferFromYuvPlanes(JNIEnv* env,
  37493. + jobject jy_plane,
  37494. + jobject ju_plane,
  37495. + jobject jv_plane,
  37496. + jint width,
  37497. + jint height,
  37498. + jint row_stride_y,
  37499. + jint row_stride_uv,
  37500. + jint pixel_stride_uv,
  37501. + jint jorientation);
  37502. } // namespace vision
  37503. } // namespace task
  37504. } // namespace tflite
  37505. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/image_searcher_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/image_searcher_jni.cc
  37506. index e57f12a16aab3..84cad5db43ea2 100644
  37507. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/image_searcher_jni.cc
  37508. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/image_searcher_jni.cc
  37509. @@ -52,7 +52,8 @@ using ::tflite::task::vision::ImageSearcherOptions;
  37510. // Creates an ImageSearcherOptions proto based on the Java class.
  37511. ImageSearcherOptions ConvertToProtoOptions(jlong base_options_handle,
  37512. - bool l2_normalize, bool quantize,
  37513. + bool l2_normalize,
  37514. + bool quantize,
  37515. int index_descriptor,
  37516. int max_results) {
  37517. ImageSearcherOptions proto_options;
  37518. @@ -124,7 +125,9 @@ jobject ConvertToSearchResults(JNIEnv* env, const SearchResult& results) {
  37519. extern "C" JNIEXPORT void JNICALL
  37520. Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_deinitJni(
  37521. - JNIEnv* env, jobject thiz, jlong native_handle) {
  37522. + JNIEnv* env,
  37523. + jobject thiz,
  37524. + jlong native_handle) {
  37525. delete reinterpret_cast<ImageSearcher*>(native_handle);
  37526. }
  37527. @@ -133,10 +136,16 @@ Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_deinitJni(
  37528. // values will be ignored.
  37529. extern "C" JNIEXPORT jlong JNICALL
  37530. Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_initJniWithModelFdAndOptions(
  37531. - JNIEnv* env, jclass thiz, jint model_descriptor,
  37532. - jlong model_descriptor_length, jlong model_descriptor_offset,
  37533. - jlong base_options_handle, bool l2_normalize, bool quantize,
  37534. - jint index_descriptor, int max_results) {
  37535. + JNIEnv* env,
  37536. + jclass thiz,
  37537. + jint model_descriptor,
  37538. + jlong model_descriptor_length,
  37539. + jlong model_descriptor_offset,
  37540. + jlong base_options_handle,
  37541. + bool l2_normalize,
  37542. + bool quantize,
  37543. + jint index_descriptor,
  37544. + int max_results) {
  37545. ImageSearcherOptions proto_options =
  37546. ConvertToProtoOptions(base_options_handle, l2_normalize, quantize,
  37547. index_descriptor, max_results);
  37548. @@ -156,8 +165,14 @@ Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_initJniWithModelFdAn
  37549. extern "C" JNIEXPORT jlong JNICALL
  37550. Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_initJniWithByteBuffer(
  37551. - JNIEnv* env, jclass thiz, jobject model_buffer, jlong base_options_handle,
  37552. - bool l2_normalize, bool quantize, jlong index_descriptor, int max_results) {
  37553. + JNIEnv* env,
  37554. + jclass thiz,
  37555. + jobject model_buffer,
  37556. + jlong base_options_handle,
  37557. + bool l2_normalize,
  37558. + bool quantize,
  37559. + jlong index_descriptor,
  37560. + int max_results) {
  37561. ImageSearcherOptions proto_options =
  37562. ConvertToProtoOptions(base_options_handle, l2_normalize, quantize,
  37563. index_descriptor, max_results);
  37564. @@ -170,7 +185,10 @@ Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_initJniWithByteBuffe
  37565. extern "C" JNIEXPORT jobject JNICALL
  37566. Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_searchNative(
  37567. - JNIEnv* env, jclass thiz, jlong native_handle, jlong frame_buffer_handle,
  37568. + JNIEnv* env,
  37569. + jclass thiz,
  37570. + jlong native_handle,
  37571. + jlong frame_buffer_handle,
  37572. jintArray jroi) {
  37573. auto* searcher = reinterpret_cast<ImageSearcher*>(native_handle);
  37574. // frame_buffer will be deleted after inference is done in
  37575. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc
  37576. index 40fa4472d37e1..8d8c8eec34295 100644
  37577. --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc
  37578. +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc
  37579. @@ -194,7 +194,9 @@ jlong CreateImageSegmenterFromOptions(JNIEnv* env,
  37580. extern "C" JNIEXPORT void JNICALL
  37581. Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_deinitJni(
  37582. - JNIEnv* env, jobject thiz, jlong native_handle) {
  37583. + JNIEnv* env,
  37584. + jobject thiz,
  37585. + jlong native_handle) {
  37586. delete reinterpret_cast<ImageSegmenter*>(native_handle);
  37587. }
  37588. @@ -203,9 +205,14 @@ Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_deinitJni(
  37589. // values will be ignored.
  37590. extern "C" JNIEXPORT jlong JNICALL
  37591. Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFdAndOptions(
  37592. - JNIEnv* env, jclass thiz, jint file_descriptor,
  37593. - jlong file_descriptor_length, jlong file_descriptor_offset,
  37594. - jstring display_names_locale, jint output_type, jlong base_options_handle) {
  37595. + JNIEnv* env,
  37596. + jclass thiz,
  37597. + jint file_descriptor,
  37598. + jlong file_descriptor_length,
  37599. + jlong file_descriptor_offset,
  37600. + jstring display_names_locale,
  37601. + jint output_type,
  37602. + jlong base_options_handle) {
  37603. ImageSegmenterOptions proto_options = ConvertToProtoOptions(
  37604. env, display_names_locale, output_type, base_options_handle);
  37605. auto file_descriptor_meta = proto_options.mutable_base_options()
  37606. @@ -223,8 +230,12 @@ Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFd
  37607. extern "C" JNIEXPORT jlong JNICALL
  37608. Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuffer(
  37609. - JNIEnv* env, jclass thiz, jobject model_buffer,
  37610. - jstring display_names_locale, jint output_type, jlong base_options_handle) {
  37611. + JNIEnv* env,
  37612. + jclass thiz,
  37613. + jobject model_buffer,
  37614. + jstring display_names_locale,
  37615. + jint output_type,
  37616. + jlong base_options_handle) {
  37617. ImageSegmenterOptions proto_options = ConvertToProtoOptions(
  37618. env, display_names_locale, output_type, base_options_handle);
  37619. proto_options.mutable_base_options()->mutable_model_file()->set_file_content(
  37620. @@ -235,8 +246,13 @@ Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuf
  37621. extern "C" JNIEXPORT void JNICALL
  37622. Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_segmentNative(
  37623. - JNIEnv* env, jclass thiz, jlong native_handle, jlong frame_buffer_handle,
  37624. - jobject jmask_buffers, jintArray jmask_shape, jobject jcolored_labels) {
  37625. + JNIEnv* env,
  37626. + jclass thiz,
  37627. + jlong native_handle,
  37628. + jlong frame_buffer_handle,
  37629. + jobject jmask_buffers,
  37630. + jintArray jmask_shape,
  37631. + jobject jcolored_labels) {
  37632. auto* segmenter = reinterpret_cast<ImageSegmenter*>(native_handle);
  37633. // frame_buffer will be deleted after inference is done in
  37634. // base_vision_api_jni.cc.
  37635. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
  37636. index 65a01c0b9d33a..2a72338741626 100644
  37637. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
  37638. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
  37639. @@ -17,13 +17,13 @@ limitations under the License.
  37640. #include <string>
  37641. -#include "absl/memory/memory.h" // from @com_google_absl
  37642. -#include "absl/status/status.h" // from @com_google_absl
  37643. -#include "absl/strings/str_format.h" // from @com_google_absl
  37644. +#include "absl/memory/memory.h" // from @com_google_absl
  37645. +#include "absl/status/status.h" // from @com_google_absl
  37646. +#include "absl/strings/str_format.h" // from @com_google_absl
  37647. #include "absl/strings/string_view.h" // from @com_google_absl
  37648. -#include "flatbuffers/flatbuffers.h" // from @flatbuffers
  37649. #include "contrib/minizip/ioapi.h"
  37650. #include "contrib/minizip/unzip.h"
  37651. +#include "flatbuffers/flatbuffers.h" // from @flatbuffers
  37652. #include "tensorflow/lite/schema/schema_generated.h"
  37653. #include "tensorflow_lite_support/cc/common.h"
  37654. #include "tensorflow_lite_support/cc/port/status_macros.h"
  37655. @@ -46,7 +46,8 @@ using ::tflite::support::TfLiteSupportStatus;
  37656. // Util to get item from src_vector specified by index.
  37657. template <typename T>
  37658. const T* GetItemFromVector(
  37659. - const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector, int index) {
  37660. + const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector,
  37661. + int index) {
  37662. if (src_vector == nullptr || index < 0 || index >= src_vector->size()) {
  37663. return nullptr;
  37664. }
  37665. @@ -158,7 +159,8 @@ ModelMetadataExtractor::FindFirstProcessUnit(
  37666. /* static */
  37667. std::string ModelMetadataExtractor::FindFirstAssociatedFileName(
  37668. const tflite::TensorMetadata& tensor_metadata,
  37669. - tflite::AssociatedFileType type, absl::string_view locale) {
  37670. + tflite::AssociatedFileType type,
  37671. + absl::string_view locale) {
  37672. if (tensor_metadata.associated_files() == nullptr) {
  37673. return std::string();
  37674. }
  37675. @@ -175,7 +177,8 @@ std::string ModelMetadataExtractor::FindFirstAssociatedFileName(
  37676. }
  37677. absl::Status ModelMetadataExtractor::InitFromModelBuffer(
  37678. - const char* buffer_data, size_t buffer_size) {
  37679. + const char* buffer_data,
  37680. + size_t buffer_size) {
  37681. // Rely on the simplest, base flatbuffers verifier. Here is not the place to
  37682. // e.g. use an OpResolver: we just want to make sure the buffer is valid to
  37683. // access the metadata.
  37684. @@ -234,7 +237,8 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer(
  37685. }
  37686. absl::Status ModelMetadataExtractor::ExtractAssociatedFiles(
  37687. - const char* buffer_data, size_t buffer_size) {
  37688. + const char* buffer_data,
  37689. + size_t buffer_size) {
  37690. // Create in-memory read-only zip file.
  37691. ZipReadOnlyMemFile mem_file = ZipReadOnlyMemFile(buffer_data, buffer_size);
  37692. // Open zip.
  37693. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h
  37694. index c2b28d18ef7d8..007919d581431 100644
  37695. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h
  37696. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h
  37697. @@ -16,8 +16,8 @@ limitations under the License.
  37698. #define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_
  37699. #include "absl/container/flat_hash_map.h" // from @com_google_absl
  37700. -#include "absl/status/status.h" // from @com_google_absl
  37701. -#include "absl/strings/string_view.h" // from @com_google_absl
  37702. +#include "absl/status/status.h" // from @com_google_absl
  37703. +#include "absl/strings/string_view.h" // from @com_google_absl
  37704. #include "tensorflow/lite/schema/schema_generated.h"
  37705. #include "tensorflow_lite_support/cc/port/statusor.h"
  37706. #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
  37707. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc
  37708. index 9d256b3322fb0..299ade3e95d54 100644
  37709. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc
  37710. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc
  37711. @@ -19,9 +19,9 @@ limitations under the License.
  37712. #include <cstring>
  37713. #include <functional>
  37714. -#include "flatbuffers/flatbuffers.h" // from @flatbuffers
  37715. #include "contrib/minizip/ioapi.h"
  37716. #include "contrib/minizip/zip.h"
  37717. +#include "flatbuffers/flatbuffers.h" // from @flatbuffers
  37718. #include "tensorflow/lite/schema/schema_generated.h"
  37719. #include "tensorflow_lite_support/cc/common.h"
  37720. #include "tensorflow_lite_support/cc/port/status_macros.h"
  37721. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h
  37722. index 510e6c04cdda1..4410f8481f97d 100644
  37723. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h
  37724. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h
  37725. @@ -17,8 +17,8 @@ limitations under the License.
  37726. #define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_POPULATOR_H_
  37727. #include "absl/container/flat_hash_map.h" // from @com_google_absl
  37728. -#include "absl/status/status.h" // from @com_google_absl
  37729. -#include "flatbuffers/flatbuffers.h" // from @flatbuffers
  37730. +#include "absl/status/status.h" // from @com_google_absl
  37731. +#include "flatbuffers/flatbuffers.h" // from @flatbuffers
  37732. #include "tensorflow/lite/schema/schema_generated.h"
  37733. #include "tensorflow_lite_support/cc/port/statusor.h"
  37734. #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
  37735. @@ -79,7 +79,8 @@ class ModelMetadataPopulator {
  37736. // Zips and appends associated files to the provided model buffer. Called
  37737. // internally by `Populate()`.
  37738. tflite::support::StatusOr<std::string> AppendAssociatedFiles(
  37739. - const char* model_buffer_data, size_t model_buffer_size);
  37740. + const char* model_buffer_data,
  37741. + size_t model_buffer_size);
  37742. // The unpacked model FlatBuffer.
  37743. tflite::ModelT model_t_;
  37744. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc
  37745. index 17ffbbc67fbec..78e9a9f1abec1 100644
  37746. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc
  37747. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc
  37748. @@ -137,7 +137,8 @@ template <typename T>
  37749. void UpdateMinimumVersionForArray(
  37750. const flatbuffers::Vector<flatbuffers::Offset<T>>* array,
  37751. Version* min_version) {
  37752. - if (array == nullptr) return;
  37753. + if (array == nullptr)
  37754. + return;
  37755. for (int i = 0; i < array->size(); ++i) {
  37756. UpdateMinimumVersionForTable<T>(array->Get(i), min_version);
  37757. @@ -146,8 +147,10 @@ void UpdateMinimumVersionForArray(
  37758. template <>
  37759. void UpdateMinimumVersionForTable<tflite::AssociatedFile>(
  37760. - const tflite::AssociatedFile* table, Version* min_version) {
  37761. - if (table == nullptr) return;
  37762. + const tflite::AssociatedFile* table,
  37763. + Version* min_version) {
  37764. + if (table == nullptr)
  37765. + return;
  37766. if (table->type() == AssociatedFileType_VOCABULARY) {
  37767. UpdateMinimumVersion(
  37768. @@ -164,8 +167,10 @@ void UpdateMinimumVersionForTable<tflite::AssociatedFile>(
  37769. template <>
  37770. void UpdateMinimumVersionForTable<tflite::ProcessUnit>(
  37771. - const tflite::ProcessUnit* table, Version* min_version) {
  37772. - if (table == nullptr) return;
  37773. + const tflite::ProcessUnit* table,
  37774. + Version* min_version) {
  37775. + if (table == nullptr)
  37776. + return;
  37777. tflite::ProcessUnitOptions process_unit_type = table->options_type();
  37778. if (process_unit_type == ProcessUnitOptions_BertTokenizerOptions) {
  37779. @@ -191,7 +196,8 @@ void UpdateMinimumVersionForTable<tflite::ProcessUnit>(
  37780. template <>
  37781. void UpdateMinimumVersionForTable<tflite::Content>(const tflite::Content* table,
  37782. Version* min_version) {
  37783. - if (table == nullptr) return;
  37784. + if (table == nullptr)
  37785. + return;
  37786. // Checks the ContenProperties field.
  37787. if (table->content_properties_type() == ContentProperties_AudioProperties) {
  37788. @@ -203,8 +209,10 @@ void UpdateMinimumVersionForTable<tflite::Content>(const tflite::Content* table,
  37789. template <>
  37790. void UpdateMinimumVersionForTable<tflite::TensorMetadata>(
  37791. - const tflite::TensorMetadata* table, Version* min_version) {
  37792. - if (table == nullptr) return;
  37793. + const tflite::TensorMetadata* table,
  37794. + Version* min_version) {
  37795. + if (table == nullptr)
  37796. + return;
  37797. // Checks the associated_files field.
  37798. UpdateMinimumVersionForArray<tflite::AssociatedFile>(
  37799. @@ -220,8 +228,10 @@ void UpdateMinimumVersionForTable<tflite::TensorMetadata>(
  37800. template <>
  37801. void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
  37802. - const tflite::SubGraphMetadata* table, Version* min_version) {
  37803. - if (table == nullptr) return;
  37804. + const tflite::SubGraphMetadata* table,
  37805. + Version* min_version) {
  37806. + if (table == nullptr)
  37807. + return;
  37808. // Checks in the input/output metadata arrays.
  37809. UpdateMinimumVersionForArray<tflite::TensorMetadata>(
  37810. @@ -268,7 +278,8 @@ void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
  37811. template <>
  37812. void UpdateMinimumVersionForTable<tflite::ModelMetadata>(
  37813. - const tflite::ModelMetadata* table, Version* min_version) {
  37814. + const tflite::ModelMetadata* table,
  37815. + Version* min_version) {
  37816. if (table == nullptr) {
  37817. // Should never happen, because VerifyModelMetadataBuffer has verified it.
  37818. TFLITE_LOG(FATAL) << "The ModelMetadata object is null.";
  37819. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.cc
  37820. index 3dac8c24af942..392b6b411fe03 100644
  37821. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.cc
  37822. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.cc
  37823. @@ -41,14 +41,17 @@ zlib_filefunc64_def& ZipReadOnlyMemFile::GetFileFunc64Def() {
  37824. }
  37825. /* static */
  37826. -voidpf ZipReadOnlyMemFile::OpenFile(voidpf opaque, const void* filename,
  37827. +voidpf ZipReadOnlyMemFile::OpenFile(voidpf opaque,
  37828. + const void* filename,
  37829. int mode) {
  37830. // Result is never used, but needs to be non-null for `zipOpen2` not to fail.
  37831. return opaque;
  37832. }
  37833. /* static */
  37834. -uLong ZipReadOnlyMemFile::ReadFile(voidpf opaque, voidpf stream, void* buf,
  37835. +uLong ZipReadOnlyMemFile::ReadFile(voidpf opaque,
  37836. + voidpf stream,
  37837. + void* buf,
  37838. uLong size) {
  37839. auto* mem_file = static_cast<ZipReadOnlyMemFile*>(opaque);
  37840. if (mem_file->offset_ < 0 || mem_file->Size() < mem_file->offset_) {
  37841. @@ -65,8 +68,10 @@ uLong ZipReadOnlyMemFile::ReadFile(voidpf opaque, voidpf stream, void* buf,
  37842. }
  37843. /* static */
  37844. -uLong ZipReadOnlyMemFile::WriteFile(voidpf opaque, voidpf stream,
  37845. - const void* buf, uLong size) {
  37846. +uLong ZipReadOnlyMemFile::WriteFile(voidpf opaque,
  37847. + voidpf stream,
  37848. + const void* buf,
  37849. + uLong size) {
  37850. // File is not writable.
  37851. return 0;
  37852. }
  37853. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h
  37854. index 13927a7afa698..a1799ff509de5 100644
  37855. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h
  37856. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h
  37857. @@ -58,7 +58,9 @@ class ZipReadOnlyMemFile {
  37858. // The file function implementations used in the `zlib_filefunc64_def`.
  37859. static voidpf OpenFile(voidpf opaque, const void* filename, int mode);
  37860. static uLong ReadFile(voidpf opaque, voidpf stream, void* buf, uLong size);
  37861. - static uLong WriteFile(voidpf opaque, voidpf stream, const void* buf,
  37862. + static uLong WriteFile(voidpf opaque,
  37863. + voidpf stream,
  37864. + const void* buf,
  37865. uLong size);
  37866. static ZPOS64_T TellFile(voidpf opaque, voidpf stream);
  37867. static long SeekFile // NOLINT
  37868. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.cc
  37869. index 5999be028689a..38ad17ad8935c 100644
  37870. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.cc
  37871. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.cc
  37872. @@ -40,17 +40,22 @@ zlib_filefunc64_def& ZipWritableMemFile::GetFileFunc64Def() {
  37873. return zlib_filefunc64_def_;
  37874. }
  37875. -absl::string_view ZipWritableMemFile::GetFileContent() const { return data_; }
  37876. +absl::string_view ZipWritableMemFile::GetFileContent() const {
  37877. + return data_;
  37878. +}
  37879. /* static */
  37880. -voidpf ZipWritableMemFile::OpenFile(voidpf opaque, const void* filename,
  37881. +voidpf ZipWritableMemFile::OpenFile(voidpf opaque,
  37882. + const void* filename,
  37883. int mode) {
  37884. // Result is never used, but needs to be non-null for `zipOpen2` not to fail.
  37885. return opaque;
  37886. }
  37887. /* static */
  37888. -uLong ZipWritableMemFile::ReadFile(voidpf opaque, voidpf stream, void* buf,
  37889. +uLong ZipWritableMemFile::ReadFile(voidpf opaque,
  37890. + voidpf stream,
  37891. + void* buf,
  37892. uLong size) {
  37893. auto* mem_file = static_cast<ZipWritableMemFile*>(opaque);
  37894. if (mem_file->offset_ < 0 || mem_file->Size() < mem_file->offset_) {
  37895. @@ -67,8 +72,10 @@ uLong ZipWritableMemFile::ReadFile(voidpf opaque, voidpf stream, void* buf,
  37896. }
  37897. /* static */
  37898. -uLong ZipWritableMemFile::WriteFile(voidpf opaque, voidpf stream,
  37899. - const void* buf, uLong size) {
  37900. +uLong ZipWritableMemFile::WriteFile(voidpf opaque,
  37901. + voidpf stream,
  37902. + const void* buf,
  37903. + uLong size) {
  37904. auto* mem_file = static_cast<ZipWritableMemFile*>(opaque);
  37905. if (mem_file->offset_ + size > mem_file->Size()) {
  37906. mem_file->data_.resize(mem_file->offset_ + size);
  37907. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h
  37908. index 762dd58f0fb41..30e42fdb72a31 100644
  37909. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h
  37910. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h
  37911. @@ -59,7 +59,9 @@ class ZipWritableMemFile {
  37912. // The file function implementations used in the `zlib_filefunc64_def`.
  37913. static voidpf OpenFile(voidpf opaque, const void* filename, int mode);
  37914. static uLong ReadFile(voidpf opaque, voidpf stream, void* buf, uLong size);
  37915. - static uLong WriteFile(voidpf opaque, voidpf stream, const void* buf,
  37916. + static uLong WriteFile(voidpf opaque,
  37917. + voidpf stream,
  37918. + const void* buf,
  37919. uLong size);
  37920. static ZPOS64_T TellFile(voidpf opaque, voidpf stream);
  37921. static long SeekFile // NOLINT
  37922. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc
  37923. index 6185722504f69..8e00452bea983 100644
  37924. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc
  37925. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc
  37926. @@ -14,7 +14,7 @@ limitations under the License.
  37927. ==============================================================================*/
  37928. #include "flatbuffers/flatbuffers.h" // from @flatbuffers
  37929. -#include "flatbuffers/idl.h" // from @flatbuffers
  37930. +#include "flatbuffers/idl.h" // from @flatbuffers
  37931. #include "pybind11/pybind11.h"
  37932. #include "pybind11/pytypes.h"
  37933. #include "pybind11/stl.h"
  37934. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java
  37935. index 6c3d23270f3f0..15bcb45c1a4b1 100644
  37936. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java
  37937. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java
  37938. @@ -33,84 +33,84 @@ import java.nio.ByteBuffer;
  37939. * synchronized as well.
  37940. */
  37941. final class BoundedInputStream extends InputStream {
  37942. - private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1);
  37943. - private final long end; // The valid data for the stream is between [start, end).
  37944. - private long position;
  37945. - private final SeekableByteChannelCompat channel;
  37946. -
  37947. - /**
  37948. - * Creates a {@link BoundedInputStream} with a {@link SeekableByteChannelCompat}.
  37949. - *
  37950. - * @param channel the {@link SeekableByteChannelCompat} that backs up this {@link
  37951. - * BoundedInputStream}
  37952. - * @param start the starting position of this {@link BoundedInputStream} in the given {@link
  37953. - * SeekableByteChannelCompat}
  37954. - * @param remaining the length of this {@link BoundedInputStream}
  37955. - * @throws IllegalArgumentException if {@code start} or {@code remaining} is negative
  37956. - */
  37957. - BoundedInputStream(SeekableByteChannelCompat channel, long start, long remaining) {
  37958. - checkArgument(
  37959. - remaining >= 0 && start >= 0,
  37960. - String.format("Invalid length of stream at offset=%d, length=%d", start, remaining));
  37961. -
  37962. - end = start + remaining;
  37963. - this.channel = channel;
  37964. - position = start;
  37965. - }
  37966. -
  37967. - @Override
  37968. - public int available() throws IOException {
  37969. - return (int) (Math.min(end, channel.size()) - position);
  37970. - }
  37971. -
  37972. - @Override
  37973. - public int read() throws IOException {
  37974. - if (position >= end) {
  37975. - return -1;
  37976. + private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1);
  37977. + private final long end; // The valid data for the stream is between [start, end).
  37978. + private long position;
  37979. + private final SeekableByteChannelCompat channel;
  37980. +
  37981. + /**
  37982. + * Creates a {@link BoundedInputStream} with a {@link SeekableByteChannelCompat}.
  37983. + *
  37984. + * @param channel the {@link SeekableByteChannelCompat} that backs up this {@link
  37985. + * BoundedInputStream}
  37986. + * @param start the starting position of this {@link BoundedInputStream} in the given {@link
  37987. + * SeekableByteChannelCompat}
  37988. + * @param remaining the length of this {@link BoundedInputStream}
  37989. + * @throws IllegalArgumentException if {@code start} or {@code remaining} is negative
  37990. + */
  37991. + BoundedInputStream(SeekableByteChannelCompat channel, long start, long remaining) {
  37992. + checkArgument(remaining >= 0 && start >= 0,
  37993. + String.format(
  37994. + "Invalid length of stream at offset=%d, length=%d", start, remaining));
  37995. +
  37996. + end = start + remaining;
  37997. + this.channel = channel;
  37998. + position = start;
  37999. }
  38000. - singleByteBuffer.rewind();
  38001. - int count = read(position, singleByteBuffer);
  38002. - if (count < 0) {
  38003. - return count;
  38004. + @Override
  38005. + public int available() throws IOException {
  38006. + return (int) (Math.min(end, channel.size()) - position);
  38007. }
  38008. - position++;
  38009. - return singleByteBuffer.get() & 0xff;
  38010. - }
  38011. + @Override
  38012. + public int read() throws IOException {
  38013. + if (position >= end) {
  38014. + return -1;
  38015. + }
  38016. - @Override
  38017. - public int read(byte[] b, int off, int len) throws IOException {
  38018. - checkNotNull(b);
  38019. - checkElementIndex(off, b.length, "The start offset");
  38020. - checkElementIndex(len, b.length - off + 1, "The maximumn number of bytes to read");
  38021. + singleByteBuffer.rewind();
  38022. + int count = read(position, singleByteBuffer);
  38023. + if (count < 0) {
  38024. + return count;
  38025. + }
  38026. - if (len == 0) {
  38027. - return 0;
  38028. + position++;
  38029. + return singleByteBuffer.get() & 0xff;
  38030. }
  38031. - if (len > end - position) {
  38032. - if (position >= end) {
  38033. - return -1;
  38034. - }
  38035. - len = (int) (end - position);
  38036. + @Override
  38037. + public int read(byte[] b, int off, int len) throws IOException {
  38038. + checkNotNull(b);
  38039. + checkElementIndex(off, b.length, "The start offset");
  38040. + checkElementIndex(len, b.length - off + 1, "The maximumn number of bytes to read");
  38041. +
  38042. + if (len == 0) {
  38043. + return 0;
  38044. + }
  38045. +
  38046. + if (len > end - position) {
  38047. + if (position >= end) {
  38048. + return -1;
  38049. + }
  38050. + len = (int) (end - position);
  38051. + }
  38052. +
  38053. + ByteBuffer buf = ByteBuffer.wrap(b, off, len);
  38054. + int count = read(position, buf);
  38055. + if (count > 0) {
  38056. + position += count;
  38057. + }
  38058. + return count;
  38059. }
  38060. - ByteBuffer buf = ByteBuffer.wrap(b, off, len);
  38061. - int count = read(position, buf);
  38062. - if (count > 0) {
  38063. - position += count;
  38064. + private int read(long position, ByteBuffer buf) throws IOException {
  38065. + int count;
  38066. + synchronized (channel) {
  38067. + channel.position(position);
  38068. + count = channel.read(buf);
  38069. + }
  38070. + buf.flip();
  38071. + return count;
  38072. }
  38073. - return count;
  38074. - }
  38075. -
  38076. - private int read(long position, ByteBuffer buf) throws IOException {
  38077. - int count;
  38078. - synchronized (channel) {
  38079. - channel.position(position);
  38080. - count = channel.read(buf);
  38081. - }
  38082. - buf.flip();
  38083. - return count;
  38084. - }
  38085. }
  38086. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java
  38087. index e5d54a415edc4..354119b02822e 100644
  38088. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java
  38089. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java
  38090. @@ -15,116 +15,114 @@ limitations under the License.
  38091. package org.tensorflow.lite.support.metadata;
  38092. -import static java.lang.Math.min;
  38093. import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
  38094. import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
  38095. +import static java.lang.Math.min;
  38096. +
  38097. import java.nio.ByteBuffer;
  38098. import java.nio.channels.NonWritableChannelException;
  38099. /** Implements the {@link SeekableByteChannelCompat} on top of {@link ByteBuffer}. */
  38100. final class ByteBufferChannel implements SeekableByteChannelCompat {
  38101. + /** The ByteBuffer that holds the data. */
  38102. + private final ByteBuffer buffer;
  38103. +
  38104. + /**
  38105. + * Creates a {@link ByteBufferChannel} that wraps a {@link ByteBuffer}.
  38106. + *
  38107. + * @param buffer the {@link ByteBuffer} that backs this {@link ByteBufferChannel}
  38108. + * @throws NullPointerException if {@code buffer} is null
  38109. + */
  38110. + public ByteBufferChannel(ByteBuffer buffer) {
  38111. + checkNotNull(buffer, "The ByteBuffer cannot be null.");
  38112. + this.buffer = buffer;
  38113. + }
  38114. +
  38115. + @Override
  38116. + public void close() {}
  38117. - /** The ByteBuffer that holds the data. */
  38118. - private final ByteBuffer buffer;
  38119. -
  38120. - /**
  38121. - * Creates a {@link ByteBufferChannel} that wraps a {@link ByteBuffer}.
  38122. - *
  38123. - * @param buffer the {@link ByteBuffer} that backs this {@link ByteBufferChannel}
  38124. - * @throws NullPointerException if {@code buffer} is null
  38125. - */
  38126. - public ByteBufferChannel(ByteBuffer buffer) {
  38127. - checkNotNull(buffer, "The ByteBuffer cannot be null.");
  38128. - this.buffer = buffer;
  38129. - }
  38130. -
  38131. - @Override
  38132. - public void close() {}
  38133. -
  38134. - @Override
  38135. - public boolean isOpen() {
  38136. - return true;
  38137. - }
  38138. -
  38139. - @Override
  38140. - public long position() {
  38141. - return buffer.position();
  38142. - }
  38143. -
  38144. - /**
  38145. - * Sets this channel's position.
  38146. - *
  38147. - * @param newPosition the new position, a non-negative integer counting the number of bytes from
  38148. - * the beginning of the entity
  38149. - * @return this channel
  38150. - * @throws IllegalArgumentException if the new position is negative, or greater than the size of
  38151. - * the underlying {@link ByteBuffer}, or greater than Integer.MAX_VALUE
  38152. - */
  38153. - @Override
  38154. - public synchronized ByteBufferChannel position(long newPosition) {
  38155. - checkArgument(
  38156. - (newPosition >= 0 && newPosition <= Integer.MAX_VALUE),
  38157. - "The new position should be non-negative and be less than Integer.MAX_VALUE.");
  38158. - buffer.position((int) newPosition);
  38159. - return this;
  38160. - }
  38161. -
  38162. - /**
  38163. - * {@inheritDoc}
  38164. - *
  38165. - * <p>Bytes are read starting at this channel's current position, and then the position is updated
  38166. - * with the number of bytes actually read. Otherwise this method behaves exactly as specified in
  38167. - * the {@link ReadableByteChannel} interface.
  38168. - */
  38169. - @Override
  38170. - public synchronized int read(ByteBuffer dst) {
  38171. - if (buffer.remaining() == 0) {
  38172. - return -1;
  38173. + @Override
  38174. + public boolean isOpen() {
  38175. + return true;
  38176. }
  38177. - int count = min(dst.remaining(), buffer.remaining());
  38178. - if (count > 0) {
  38179. - ByteBuffer tempBuffer = buffer.slice();
  38180. - tempBuffer.order(buffer.order()).limit(count);
  38181. - dst.put(tempBuffer);
  38182. - buffer.position(buffer.position() + count);
  38183. + @Override
  38184. + public long position() {
  38185. + return buffer.position();
  38186. }
  38187. - return count;
  38188. - }
  38189. -
  38190. - @Override
  38191. - public long size() {
  38192. - return buffer.limit();
  38193. - }
  38194. -
  38195. - @Override
  38196. - public synchronized ByteBufferChannel truncate(long size) {
  38197. - checkArgument(
  38198. - (size >= 0 && size <= Integer.MAX_VALUE),
  38199. - "The new size should be non-negative and be less than Integer.MAX_VALUE.");
  38200. -
  38201. - if (size < buffer.limit()) {
  38202. - buffer.limit((int) size);
  38203. - if (buffer.position() > size) {
  38204. - buffer.position((int) size);
  38205. - }
  38206. +
  38207. + /**
  38208. + * Sets this channel's position.
  38209. + *
  38210. + * @param newPosition the new position, a non-negative integer counting the number of bytes from
  38211. + * the beginning of the entity
  38212. + * @return this channel
  38213. + * @throws IllegalArgumentException if the new position is negative, or greater than the size of
  38214. + * the underlying {@link ByteBuffer}, or greater than Integer.MAX_VALUE
  38215. + */
  38216. + @Override
  38217. + public synchronized ByteBufferChannel position(long newPosition) {
  38218. + checkArgument((newPosition >= 0 && newPosition <= Integer.MAX_VALUE),
  38219. + "The new position should be non-negative and be less than Integer.MAX_VALUE.");
  38220. + buffer.position((int) newPosition);
  38221. + return this;
  38222. + }
  38223. +
  38224. + /**
  38225. + * {@inheritDoc}
  38226. + *
  38227. + * <p>Bytes are read starting at this channel's current position, and then the position is
  38228. + * updated with the number of bytes actually read. Otherwise this method behaves exactly as
  38229. + * specified in the {@link ReadableByteChannel} interface.
  38230. + */
  38231. + @Override
  38232. + public synchronized int read(ByteBuffer dst) {
  38233. + if (buffer.remaining() == 0) {
  38234. + return -1;
  38235. + }
  38236. +
  38237. + int count = min(dst.remaining(), buffer.remaining());
  38238. + if (count > 0) {
  38239. + ByteBuffer tempBuffer = buffer.slice();
  38240. + tempBuffer.order(buffer.order()).limit(count);
  38241. + dst.put(tempBuffer);
  38242. + buffer.position(buffer.position() + count);
  38243. + }
  38244. + return count;
  38245. + }
  38246. +
  38247. + @Override
  38248. + public long size() {
  38249. + return buffer.limit();
  38250. }
  38251. - return this;
  38252. - }
  38253. - @Override
  38254. - public synchronized int write(ByteBuffer src) {
  38255. - if (buffer.isReadOnly()) {
  38256. - throw new NonWritableChannelException();
  38257. + @Override
  38258. + public synchronized ByteBufferChannel truncate(long size) {
  38259. + checkArgument((size >= 0 && size <= Integer.MAX_VALUE),
  38260. + "The new size should be non-negative and be less than Integer.MAX_VALUE.");
  38261. +
  38262. + if (size < buffer.limit()) {
  38263. + buffer.limit((int) size);
  38264. + if (buffer.position() > size) {
  38265. + buffer.position((int) size);
  38266. + }
  38267. + }
  38268. + return this;
  38269. }
  38270. - int count = min(src.remaining(), buffer.remaining());
  38271. - if (count > 0) {
  38272. - ByteBuffer tempBuffer = src.slice();
  38273. - tempBuffer.order(buffer.order()).limit(count);
  38274. - buffer.put(tempBuffer);
  38275. + @Override
  38276. + public synchronized int write(ByteBuffer src) {
  38277. + if (buffer.isReadOnly()) {
  38278. + throw new NonWritableChannelException();
  38279. + }
  38280. +
  38281. + int count = min(src.remaining(), buffer.remaining());
  38282. + if (count > 0) {
  38283. + ByteBuffer tempBuffer = src.slice();
  38284. + tempBuffer.order(buffer.order()).limit(count);
  38285. + buffer.put(tempBuffer);
  38286. + }
  38287. + return count;
  38288. }
  38289. - return count;
  38290. - }
  38291. }
  38292. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
  38293. index 183d416481156..3fb3c48118748 100644
  38294. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
  38295. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
  38296. @@ -17,15 +17,16 @@ package org.tensorflow.lite.support.metadata;
  38297. import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
  38298. +import org.checkerframework.checker.nullness.qual.Nullable;
  38299. +import org.tensorflow.lite.schema.Tensor;
  38300. +import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
  38301. +import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
  38302. +
  38303. import java.io.IOException;
  38304. import java.io.InputStream;
  38305. import java.nio.ByteBuffer;
  38306. import java.util.Set;
  38307. import java.util.zip.ZipException;
  38308. -import org.checkerframework.checker.nullness.qual.Nullable;
  38309. -import org.tensorflow.lite.schema.Tensor;
  38310. -import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
  38311. -import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
  38312. /**
  38313. * Loads metadata from TFLite Model FlatBuffer.
  38314. @@ -53,328 +54,329 @@ import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
  38315. * MetadataExtractor} omits subgraph index as an input in its methods.
  38316. */
  38317. public class MetadataExtractor {
  38318. + /** The helper class to load metadata from TFLite model FlatBuffer. */
  38319. + private final ModelInfo modelInfo;
  38320. +
  38321. + /** The helper class to load metadata from TFLite metadata FlatBuffer. */
  38322. + @Nullable
  38323. + private final ModelMetadataInfo metadataInfo;
  38324. +
  38325. + /** The handler to load associated files through zip. */
  38326. + @Nullable
  38327. + private final ZipFile zipFile;
  38328. +
  38329. + /**
  38330. + * Creates a {@link MetadataExtractor} with TFLite model FlatBuffer.
  38331. + *
  38332. + * @param buffer the TFLite model FlatBuffer
  38333. + * @throws IllegalArgumentException if the number of input or output tensors in the model does
  38334. + * not
  38335. + * match that in the metadata
  38336. + * @throws IOException if an error occurs while reading the model as a Zip file
  38337. + */
  38338. + public MetadataExtractor(ByteBuffer buffer) throws IOException {
  38339. + modelInfo = new ModelInfo(buffer);
  38340. + ByteBuffer metadataBuffer = modelInfo.getMetadataBuffer();
  38341. + if (metadataBuffer != null) {
  38342. + metadataInfo = new ModelMetadataInfo(metadataBuffer);
  38343. +
  38344. + // Prints warning message if the minimum parser version is not satisfied.
  38345. + if (!isMinimumParserVersionSatisfied()) {
  38346. + System.err.printf(
  38347. + "<Warning> Some fields in the metadata belong to a future schema. The minimum parser"
  38348. + + " version required is %s, but the version of the current metadata parser is %s",
  38349. + metadataInfo.getMininumParserVersion(), MetadataParser.VERSION);
  38350. + }
  38351. +
  38352. + checkArgument(modelInfo.getInputTensorCount() == metadataInfo.getInputTensorCount(),
  38353. + String.format(
  38354. + "The number of input tensors in the model is %d. The number of input tensors that"
  38355. + + " recorded in the metadata is %d. These two values does not match.",
  38356. + modelInfo.getInputTensorCount(), metadataInfo.getInputTensorCount()));
  38357. + checkArgument(modelInfo.getOutputTensorCount() == metadataInfo.getOutputTensorCount(),
  38358. + String.format(
  38359. + "The number of output tensors in the model is %d. The number of output tensors that"
  38360. + + " recorded in the metadata is %d. These two values does not match.",
  38361. + modelInfo.getOutputTensorCount(), metadataInfo.getOutputTensorCount()));
  38362. + } else {
  38363. + // It is allowed to pass in a model FlatBuffer without TFLite metadata. However,
  38364. + // invoking methods that read from TFLite metadata will cause runtime errors.
  38365. + metadataInfo = null;
  38366. + }
  38367. +
  38368. + zipFile = createZipFile(buffer);
  38369. + }
  38370. - /** The helper class to load metadata from TFLite model FlatBuffer. */
  38371. - private final ModelInfo modelInfo;
  38372. -
  38373. - /** The helper class to load metadata from TFLite metadata FlatBuffer. */
  38374. - @Nullable private final ModelMetadataInfo metadataInfo;
  38375. -
  38376. - /** The handler to load associated files through zip. */
  38377. - @Nullable private final ZipFile zipFile;
  38378. -
  38379. - /**
  38380. - * Creates a {@link MetadataExtractor} with TFLite model FlatBuffer.
  38381. - *
  38382. - * @param buffer the TFLite model FlatBuffer
  38383. - * @throws IllegalArgumentException if the number of input or output tensors in the model does not
  38384. - * match that in the metadata
  38385. - * @throws IOException if an error occurs while reading the model as a Zip file
  38386. - */
  38387. - public MetadataExtractor(ByteBuffer buffer) throws IOException {
  38388. - modelInfo = new ModelInfo(buffer);
  38389. - ByteBuffer metadataBuffer = modelInfo.getMetadataBuffer();
  38390. - if (metadataBuffer != null) {
  38391. - metadataInfo = new ModelMetadataInfo(metadataBuffer);
  38392. -
  38393. - // Prints warning message if the minimum parser version is not satisfied.
  38394. - if (!isMinimumParserVersionSatisfied()) {
  38395. - System.err.printf(
  38396. - "<Warning> Some fields in the metadata belong to a future schema. The minimum parser"
  38397. - + " version required is %s, but the version of the current metadata parser is %s",
  38398. - metadataInfo.getMininumParserVersion(), MetadataParser.VERSION);
  38399. - }
  38400. -
  38401. - checkArgument(
  38402. - modelInfo.getInputTensorCount() == metadataInfo.getInputTensorCount(),
  38403. - String.format(
  38404. - "The number of input tensors in the model is %d. The number of input tensors that"
  38405. - + " recorded in the metadata is %d. These two values does not match.",
  38406. - modelInfo.getInputTensorCount(), metadataInfo.getInputTensorCount()));
  38407. - checkArgument(
  38408. - modelInfo.getOutputTensorCount() == metadataInfo.getOutputTensorCount(),
  38409. - String.format(
  38410. - "The number of output tensors in the model is %d. The number of output tensors that"
  38411. - + " recorded in the metadata is %d. These two values does not match.",
  38412. - modelInfo.getOutputTensorCount(), metadataInfo.getOutputTensorCount()));
  38413. - } else {
  38414. - // It is allowed to pass in a model FlatBuffer without TFLite metadata. However, invoking
  38415. - // methods that read from TFLite metadata will cause runtime errors.
  38416. - metadataInfo = null;
  38417. + /**
  38418. + * Quantization parameters that corresponds to the table, {@code QuantizationParameters}, in the
  38419. + * <a
  38420. + * href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite
  38421. + * Model schema file.</a>
  38422. + *
  38423. + * <p>Since per-channel quantization does not apply to input and output tensors, {@code scale}
  38424. + * and
  38425. + * {@code zero_point} are both single values instead of arrays.
  38426. + *
  38427. + * <p>For tensor that are not quantized, the values of scale and zero_point are both 0.
  38428. + *
  38429. + * <p>Given a quantized value q, the corresponding float value f should be: <br>
  38430. + * f = scale * (q - zero_point) <br>
  38431. + */
  38432. + public static class QuantizationParams {
  38433. + /** The scale value used in quantization. */
  38434. + private final float scale;
  38435. + /** The zero point value used in quantization. */
  38436. + private final int zeroPoint;
  38437. +
  38438. + /**
  38439. + * Creates a {@link QuantizationParams} with {@code scale} and {@code zero_point}.
  38440. + *
  38441. + * @param scale The scale value used in quantization.
  38442. + * @param zeroPoint The zero point value used in quantization.
  38443. + */
  38444. + public QuantizationParams(final float scale, final int zeroPoint) {
  38445. + this.scale = scale;
  38446. + this.zeroPoint = zeroPoint;
  38447. + }
  38448. +
  38449. + /** Returns the scale value. */
  38450. + public float getScale() {
  38451. + return scale;
  38452. + }
  38453. +
  38454. + /** Returns the zero point value. */
  38455. + public int getZeroPoint() {
  38456. + return zeroPoint;
  38457. + }
  38458. }
  38459. - zipFile = createZipFile(buffer);
  38460. - }
  38461. -
  38462. - /**
  38463. - * Quantization parameters that corresponds to the table, {@code QuantizationParameters}, in the
  38464. - * <a
  38465. - * href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite
  38466. - * Model schema file.</a>
  38467. - *
  38468. - * <p>Since per-channel quantization does not apply to input and output tensors, {@code scale} and
  38469. - * {@code zero_point} are both single values instead of arrays.
  38470. - *
  38471. - * <p>For tensor that are not quantized, the values of scale and zero_point are both 0.
  38472. - *
  38473. - * <p>Given a quantized value q, the corresponding float value f should be: <br>
  38474. - * f = scale * (q - zero_point) <br>
  38475. - */
  38476. - public static class QuantizationParams {
  38477. - /** The scale value used in quantization. */
  38478. - private final float scale;
  38479. - /** The zero point value used in quantization. */
  38480. - private final int zeroPoint;
  38481. + /** Returns {@code true} if the model has metadata. Otherwise, returns {@code false}. */
  38482. + public boolean hasMetadata() {
  38483. + return metadataInfo != null;
  38484. + }
  38485. /**
  38486. - * Creates a {@link QuantizationParams} with {@code scale} and {@code zero_point}.
  38487. + * Gets the packed associated file with the specified {@code fileName}.
  38488. *
  38489. - * @param scale The scale value used in quantization.
  38490. - * @param zeroPoint The zero point value used in quantization.
  38491. + * @param fileName the name of the associated file
  38492. + * @return the raw input stream containing specified file
  38493. + * @throws IllegalStateException if the model is not a zip file
  38494. + * @throws IllegalArgumentException if the specified file does not exist in the model
  38495. */
  38496. - public QuantizationParams(final float scale, final int zeroPoint) {
  38497. - this.scale = scale;
  38498. - this.zeroPoint = zeroPoint;
  38499. + public InputStream getAssociatedFile(String fileName) {
  38500. + assertZipFile();
  38501. + return zipFile.getRawInputStream(fileName);
  38502. }
  38503. - /** Returns the scale value. */
  38504. - public float getScale() {
  38505. - return scale;
  38506. + /**
  38507. + * Gets the file names of the associated files.
  38508. + *
  38509. + * @return the file names of the associated files
  38510. + * @throws IllegalStateException if the model is not a zip file
  38511. + */
  38512. + public Set<String> getAssociatedFileNames() {
  38513. + assertZipFile();
  38514. + return zipFile.getFileNames();
  38515. }
  38516. - /** Returns the zero point value. */
  38517. - public int getZeroPoint() {
  38518. - return zeroPoint;
  38519. + /** Gets the count of input tensors in the model. */
  38520. + public int getInputTensorCount() {
  38521. + return modelInfo.getInputTensorCount();
  38522. }
  38523. - }
  38524. -
  38525. - /** Returns {@code true} if the model has metadata. Otherwise, returns {@code false}. */
  38526. - public boolean hasMetadata() {
  38527. - return metadataInfo != null;
  38528. - }
  38529. -
  38530. - /**
  38531. - * Gets the packed associated file with the specified {@code fileName}.
  38532. - *
  38533. - * @param fileName the name of the associated file
  38534. - * @return the raw input stream containing specified file
  38535. - * @throws IllegalStateException if the model is not a zip file
  38536. - * @throws IllegalArgumentException if the specified file does not exist in the model
  38537. - */
  38538. - public InputStream getAssociatedFile(String fileName) {
  38539. - assertZipFile();
  38540. - return zipFile.getRawInputStream(fileName);
  38541. - }
  38542. -
  38543. - /**
  38544. - * Gets the file names of the associated files.
  38545. - *
  38546. - * @return the file names of the associated files
  38547. - * @throws IllegalStateException if the model is not a zip file
  38548. - */
  38549. - public Set<String> getAssociatedFileNames() {
  38550. - assertZipFile();
  38551. - return zipFile.getFileNames();
  38552. - }
  38553. -
  38554. - /** Gets the count of input tensors in the model. */
  38555. - public int getInputTensorCount() {
  38556. - return modelInfo.getInputTensorCount();
  38557. - }
  38558. -
  38559. - /**
  38560. - * Gets the metadata for the input tensor specified by {@code inputIndex}.
  38561. - *
  38562. - * @param inputIndex the index of the desired input tensor
  38563. - * @throws IllegalStateException if this model does not contain model metadata
  38564. - */
  38565. - @Nullable
  38566. - public TensorMetadata getInputTensorMetadata(int inputIndex) {
  38567. - assertMetadataInfo();
  38568. - return metadataInfo.getInputTensorMetadata(inputIndex);
  38569. - }
  38570. -
  38571. - /**
  38572. - * Gets the quantization parameters for the input tensor specified by {@code inputIndex}.
  38573. - *
  38574. - * @param inputIndex the index of the desired input tensor
  38575. - */
  38576. - public QuantizationParams getInputTensorQuantizationParams(int inputIndex) {
  38577. - Tensor tensor = modelInfo.getInputTensor(inputIndex);
  38578. - return modelInfo.getQuantizationParams(tensor);
  38579. - }
  38580. -
  38581. - /**
  38582. - * Gets the shape of the input tensor with {@code inputIndex}.
  38583. - *
  38584. - * @param inputIndex the index of the desired input tensor
  38585. - */
  38586. - public int[] getInputTensorShape(int inputIndex) {
  38587. - return modelInfo.getInputTensorShape(inputIndex);
  38588. - }
  38589. -
  38590. - /**
  38591. - * Gets the {@link TensorType} of the input tensor with {@code inputIndex}.
  38592. - *
  38593. - * @param inputIndex the index of the desired input tensor
  38594. - */
  38595. - public byte getInputTensorType(int inputIndex) {
  38596. - return modelInfo.getInputTensorType(inputIndex);
  38597. - }
  38598. -
  38599. - /**
  38600. - * Gets the root handler for the model metadata.
  38601. - *
  38602. - * @throws IllegalStateException if this model does not contain model metadata
  38603. - */
  38604. - public ModelMetadata getModelMetadata() {
  38605. - assertMetadataInfo();
  38606. - return metadataInfo.getModelMetadata();
  38607. - }
  38608. -
  38609. - /** Gets the count of output tensors in the model. */
  38610. - public int getOutputTensorCount() {
  38611. - return modelInfo.getOutputTensorCount();
  38612. - }
  38613. -
  38614. - /**
  38615. - * Gets the metadata for the output tensor specified by {@code outputIndex}.
  38616. - *
  38617. - * @param outputIndex the index of the desired output tensor
  38618. - * @throws IllegalStateException if this model does not contain model metadata
  38619. - */
  38620. - @Nullable
  38621. - public TensorMetadata getOutputTensorMetadata(int outputIndex) {
  38622. - assertMetadataInfo();
  38623. - return metadataInfo.getOutputTensorMetadata(outputIndex);
  38624. - }
  38625. -
  38626. - /**
  38627. - * Gets the quantization parameters for the output tensor specified by {@code outputIndex}.
  38628. - *
  38629. - * @param outputIndex the index of the desired output tensor
  38630. - */
  38631. - public QuantizationParams getOutputTensorQuantizationParams(int outputIndex) {
  38632. - Tensor tensor = modelInfo.getOutputTensor(outputIndex);
  38633. - return modelInfo.getQuantizationParams(tensor);
  38634. - }
  38635. -
  38636. - /**
  38637. - * Gets the shape of the output tensor with {@code outputIndex}.
  38638. - *
  38639. - * @param outputIndex the index of the desired output tensor
  38640. - */
  38641. - public int[] getOutputTensorShape(int outputIndex) {
  38642. - return modelInfo.getOutputTensorShape(outputIndex);
  38643. - }
  38644. -
  38645. - /**
  38646. - * Gets the {@link TensorType} of the output tensor with {@code outputIndex}.
  38647. - *
  38648. - * @param outputIndex the index of the desired output tensor
  38649. - */
  38650. - public byte getOutputTensorType(int outputIndex) {
  38651. - return modelInfo.getOutputTensorType(outputIndex);
  38652. - }
  38653. -
  38654. - /**
  38655. - * Returns {@code true} if the minimum parser version required by the given metadata flatbuffer
  38656. - * precedes or equals to the version of the metadata parser that this MetadataExtractor library is
  38657. - * relying on. All fields in the metadata can be parsed correctly with this metadata extractor
  38658. - * library in this case. Otherwise, it returns {@code false}.
  38659. - *
  38660. - * <p>For example, assume the underlying metadata parser version is {@code 1.14.1},
  38661. - *
  38662. - * <ul>
  38663. - * <li>it returns {@code true}, if the required minimum parser version is the same or older,
  38664. - * such as {@code 1.14.1} or {@code 1.14.0}. Null version precedes all numeric versions,
  38665. - * because some metadata flatbuffers are generated before the first versioned release; <br>
  38666. - * <li>it returns {@code false}, if the required minimum parser version is newer, such as {@code
  38667. - * 1.14.2}.
  38668. - * </ul>
  38669. - */
  38670. - public final boolean isMinimumParserVersionSatisfied() {
  38671. - String minVersion = metadataInfo.getMininumParserVersion();
  38672. - if (minVersion == null) {
  38673. - return true;
  38674. +
  38675. + /**
  38676. + * Gets the metadata for the input tensor specified by {@code inputIndex}.
  38677. + *
  38678. + * @param inputIndex the index of the desired input tensor
  38679. + * @throws IllegalStateException if this model does not contain model metadata
  38680. + */
  38681. + @Nullable
  38682. + public TensorMetadata getInputTensorMetadata(int inputIndex) {
  38683. + assertMetadataInfo();
  38684. + return metadataInfo.getInputTensorMetadata(inputIndex);
  38685. }
  38686. - return compareVersions(minVersion, MetadataParser.VERSION) <= 0;
  38687. - }
  38688. -
  38689. - /**
  38690. - * Asserts if {@link #metadataInfo} is not initialized. Some models may not have metadata and this
  38691. - * is allowed. However, invoking methods that reads the metadata is not allowed.
  38692. - *
  38693. - * @throws IllegalStateException if this model does not contain model metadata
  38694. - */
  38695. - private void assertMetadataInfo() {
  38696. - if (metadataInfo == null) {
  38697. - throw new IllegalStateException("This model does not contain model metadata.");
  38698. +
  38699. + /**
  38700. + * Gets the quantization parameters for the input tensor specified by {@code inputIndex}.
  38701. + *
  38702. + * @param inputIndex the index of the desired input tensor
  38703. + */
  38704. + public QuantizationParams getInputTensorQuantizationParams(int inputIndex) {
  38705. + Tensor tensor = modelInfo.getInputTensor(inputIndex);
  38706. + return modelInfo.getQuantizationParams(tensor);
  38707. }
  38708. - }
  38709. -
  38710. - /**
  38711. - * Asserts if {@link #zipFile} is not initialized. Some models may not have associated files, thus
  38712. - * are not Zip files. This is allowed. However, invoking methods that reads those associated files
  38713. - * is not allowed.
  38714. - *
  38715. - * @throws IllegalStateException if this model is not a Zip file
  38716. - */
  38717. - private void assertZipFile() {
  38718. - if (zipFile == null) {
  38719. - throw new IllegalStateException(
  38720. - "This model does not contain associated files, and is not a Zip file.");
  38721. +
  38722. + /**
  38723. + * Gets the shape of the input tensor with {@code inputIndex}.
  38724. + *
  38725. + * @param inputIndex the index of the desired input tensor
  38726. + */
  38727. + public int[] getInputTensorShape(int inputIndex) {
  38728. + return modelInfo.getInputTensorShape(inputIndex);
  38729. }
  38730. - }
  38731. -
  38732. - /**
  38733. - * Creates a Zip file handler to read the associated files. If the model is not a zip file, i.e.
  38734. - * it does not have associated files, return a null handler.
  38735. - *
  38736. - * @param buffer the TFLite model FlatBuffer
  38737. - * @throws IOException if an error occurs while reading the model as a Zip file
  38738. - */
  38739. - @Nullable
  38740. - private static ZipFile createZipFile(ByteBuffer buffer) throws IOException {
  38741. - try {
  38742. - // Creates the handler to hold the associated files through the Zip.
  38743. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(buffer);
  38744. - return ZipFile.createFrom(byteBufferChannel);
  38745. - } catch (ZipException e) {
  38746. - // Some models may not have associate files. Therefore, Those models are not zip files.
  38747. - // However, invoking methods that read associated files later will lead into errors.
  38748. - return null;
  38749. +
  38750. + /**
  38751. + * Gets the {@link TensorType} of the input tensor with {@code inputIndex}.
  38752. + *
  38753. + * @param inputIndex the index of the desired input tensor
  38754. + */
  38755. + public byte getInputTensorType(int inputIndex) {
  38756. + return modelInfo.getInputTensorType(inputIndex);
  38757. }
  38758. - }
  38759. -
  38760. - /**
  38761. - * Compares two semantic version numbers.
  38762. - *
  38763. - * <p>Examples of comparing two versions: <br>
  38764. - * {@code 1.9} precedes {@code 1.14}; <br>
  38765. - * {@code 1.14} precedes {@code 1.14.1}; <br>
  38766. - * {@code 1.14} and {@code 1.14.0} are euqal;
  38767. - *
  38768. - * @return the value {@code 0} if the two versions are equal; a value less than {@code 0} if
  38769. - * {@code version1} precedes {@code version2}; a value greater than {@code 0} if {@code
  38770. - * version2} precedes {@code version1}.
  38771. - */
  38772. - private static int compareVersions(String version1, String version2) {
  38773. - // Using String.split instead of the recommanded Guava Splitter because we've been avoiding
  38774. - // depending on other third party libraries in this project.
  38775. - String[] levels1 = version1.split("\\.", 0);
  38776. - String[] levels2 = version2.split("\\.", 0);
  38777. -
  38778. - int length = Math.max(levels1.length, levels2.length);
  38779. - for (int i = 0; i < length; i++) {
  38780. - Integer v1 = i < levels1.length ? Integer.parseInt(levels1[i]) : 0;
  38781. - Integer v2 = i < levels2.length ? Integer.parseInt(levels2[i]) : 0;
  38782. - int compare = v1.compareTo(v2);
  38783. - if (compare != 0) {
  38784. - return compare;
  38785. - }
  38786. +
  38787. + /**
  38788. + * Gets the root handler for the model metadata.
  38789. + *
  38790. + * @throws IllegalStateException if this model does not contain model metadata
  38791. + */
  38792. + public ModelMetadata getModelMetadata() {
  38793. + assertMetadataInfo();
  38794. + return metadataInfo.getModelMetadata();
  38795. + }
  38796. +
  38797. + /** Gets the count of output tensors in the model. */
  38798. + public int getOutputTensorCount() {
  38799. + return modelInfo.getOutputTensorCount();
  38800. }
  38801. - return 0;
  38802. - }
  38803. + /**
  38804. + * Gets the metadata for the output tensor specified by {@code outputIndex}.
  38805. + *
  38806. + * @param outputIndex the index of the desired output tensor
  38807. + * @throws IllegalStateException if this model does not contain model metadata
  38808. + */
  38809. + @Nullable
  38810. + public TensorMetadata getOutputTensorMetadata(int outputIndex) {
  38811. + assertMetadataInfo();
  38812. + return metadataInfo.getOutputTensorMetadata(outputIndex);
  38813. + }
  38814. +
  38815. + /**
  38816. + * Gets the quantization parameters for the output tensor specified by {@code outputIndex}.
  38817. + *
  38818. + * @param outputIndex the index of the desired output tensor
  38819. + */
  38820. + public QuantizationParams getOutputTensorQuantizationParams(int outputIndex) {
  38821. + Tensor tensor = modelInfo.getOutputTensor(outputIndex);
  38822. + return modelInfo.getQuantizationParams(tensor);
  38823. + }
  38824. +
  38825. + /**
  38826. + * Gets the shape of the output tensor with {@code outputIndex}.
  38827. + *
  38828. + * @param outputIndex the index of the desired output tensor
  38829. + */
  38830. + public int[] getOutputTensorShape(int outputIndex) {
  38831. + return modelInfo.getOutputTensorShape(outputIndex);
  38832. + }
  38833. +
  38834. + /**
  38835. + * Gets the {@link TensorType} of the output tensor with {@code outputIndex}.
  38836. + *
  38837. + * @param outputIndex the index of the desired output tensor
  38838. + */
  38839. + public byte getOutputTensorType(int outputIndex) {
  38840. + return modelInfo.getOutputTensorType(outputIndex);
  38841. + }
  38842. +
  38843. + /**
  38844. + * Returns {@code true} if the minimum parser version required by the given metadata flatbuffer
  38845. + * precedes or equals to the version of the metadata parser that this MetadataExtractor library
  38846. + * is relying on. All fields in the metadata can be parsed correctly with this metadata
  38847. + * extractor library in this case. Otherwise, it returns {@code false}.
  38848. + *
  38849. + * <p>For example, assume the underlying metadata parser version is {@code 1.14.1},
  38850. + *
  38851. + * <ul>
  38852. + * <li>it returns {@code true}, if the required minimum parser version is the same or older,
  38853. + * such as {@code 1.14.1} or {@code 1.14.0}. Null version precedes all numeric versions,
  38854. + * because some metadata flatbuffers are generated before the first versioned release;
  38855. + * <br> <li>it returns {@code false}, if the required minimum parser version is newer, such as
  38856. + * {@code 1.14.2}.
  38857. + * </ul>
  38858. + */
  38859. + public final boolean isMinimumParserVersionSatisfied() {
  38860. + String minVersion = metadataInfo.getMininumParserVersion();
  38861. + if (minVersion == null) {
  38862. + return true;
  38863. + }
  38864. + return compareVersions(minVersion, MetadataParser.VERSION) <= 0;
  38865. + }
  38866. +
  38867. + /**
  38868. + * Asserts if {@link #metadataInfo} is not initialized. Some models may not have metadata and
  38869. + * this is allowed. However, invoking methods that reads the metadata is not allowed.
  38870. + *
  38871. + * @throws IllegalStateException if this model does not contain model metadata
  38872. + */
  38873. + private void assertMetadataInfo() {
  38874. + if (metadataInfo == null) {
  38875. + throw new IllegalStateException("This model does not contain model metadata.");
  38876. + }
  38877. + }
  38878. +
  38879. + /**
  38880. + * Asserts if {@link #zipFile} is not initialized. Some models may not have associated files,
  38881. + * thus are not Zip files. This is allowed. However, invoking methods that reads those
  38882. + * associated files is not allowed.
  38883. + *
  38884. + * @throws IllegalStateException if this model is not a Zip file
  38885. + */
  38886. + private void assertZipFile() {
  38887. + if (zipFile == null) {
  38888. + throw new IllegalStateException(
  38889. + "This model does not contain associated files, and is not a Zip file.");
  38890. + }
  38891. + }
  38892. +
  38893. + /**
  38894. + * Creates a Zip file handler to read the associated files. If the model is not a zip file, i.e.
  38895. + * it does not have associated files, return a null handler.
  38896. + *
  38897. + * @param buffer the TFLite model FlatBuffer
  38898. + * @throws IOException if an error occurs while reading the model as a Zip file
  38899. + */
  38900. + @Nullable
  38901. + private static ZipFile createZipFile(ByteBuffer buffer) throws IOException {
  38902. + try {
  38903. + // Creates the handler to hold the associated files through the Zip.
  38904. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(buffer);
  38905. + return ZipFile.createFrom(byteBufferChannel);
  38906. + } catch (ZipException e) {
  38907. + // Some models may not have associate files. Therefore, Those models are not zip files.
  38908. + // However, invoking methods that read associated files later will lead into errors.
  38909. + return null;
  38910. + }
  38911. + }
  38912. +
  38913. + /**
  38914. + * Compares two semantic version numbers.
  38915. + *
  38916. + * <p>Examples of comparing two versions: <br>
  38917. + * {@code 1.9} precedes {@code 1.14}; <br>
  38918. + * {@code 1.14} precedes {@code 1.14.1}; <br>
  38919. + * {@code 1.14} and {@code 1.14.0} are euqal;
  38920. + *
  38921. + * @return the value {@code 0} if the two versions are equal; a value less than {@code 0} if
  38922. + * {@code version1} precedes {@code version2}; a value greater than {@code 0} if {@code
  38923. + * version2} precedes {@code version1}.
  38924. + */
  38925. + private static int compareVersions(String version1, String version2) {
  38926. + // Using String.split instead of the recommanded Guava Splitter because we've been avoiding
  38927. + // depending on other third party libraries in this project.
  38928. + String[] levels1 = version1.split("\\.", 0);
  38929. + String[] levels2 = version2.split("\\.", 0);
  38930. +
  38931. + int length = Math.max(levels1.length, levels2.length);
  38932. + for (int i = 0; i < length; i++) {
  38933. + Integer v1 = i < levels1.length ? Integer.parseInt(levels1[i]) : 0;
  38934. + Integer v2 = i < levels2.length ? Integer.parseInt(levels2[i]) : 0;
  38935. + int compare = v1.compareTo(v2);
  38936. + if (compare != 0) {
  38937. + return compare;
  38938. + }
  38939. + }
  38940. +
  38941. + return 0;
  38942. + }
  38943. }
  38944. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java
  38945. index 8a262a02eab14..1dbf9ebb46386 100644
  38946. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java
  38947. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java
  38948. @@ -17,11 +17,11 @@ package org.tensorflow.lite.support.metadata;
  38949. /** Information about the metadata parser that this metadata extractor library is depending on. */
  38950. public final class MetadataParser {
  38951. - /**
  38952. - * The version of the metadata parser that this metadata extractor library is depending on. The
  38953. - * value should match the value of "Schema Semantic version" in metadata_schema.fbs.
  38954. - */
  38955. - public static final String VERSION = "1.4.0";
  38956. + /**
  38957. + * The version of the metadata parser that this metadata extractor library is depending on. The
  38958. + * value should match the value of "Schema Semantic version" in metadata_schema.fbs.
  38959. + */
  38960. + public static final String VERSION = "1.4.0";
  38961. - private MetadataParser() {}
  38962. + private MetadataParser() {}
  38963. }
  38964. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java
  38965. index 309a3dbe77470..863ab83e306fb 100644
  38966. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java
  38967. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java
  38968. @@ -18,10 +18,6 @@ package org.tensorflow.lite.support.metadata;
  38969. import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
  38970. import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
  38971. -import java.nio.ByteBuffer;
  38972. -import java.util.ArrayList;
  38973. -import java.util.Collections;
  38974. -import java.util.List;
  38975. import org.checkerframework.checker.nullness.qual.Nullable;
  38976. import org.tensorflow.lite.schema.Buffer;
  38977. import org.tensorflow.lite.schema.Metadata;
  38978. @@ -32,235 +28,237 @@ import org.tensorflow.lite.schema.Tensor;
  38979. import org.tensorflow.lite.schema.TensorType;
  38980. import org.tensorflow.lite.support.metadata.MetadataExtractor.QuantizationParams;
  38981. +import java.nio.ByteBuffer;
  38982. +import java.util.ArrayList;
  38983. +import java.util.Collections;
  38984. +import java.util.List;
  38985. +
  38986. /** Extracts model information out of TFLite model FLatBuffer. */
  38987. final class ModelInfo {
  38988. - /** The model that is loaded from TFLite model FlatBuffer. */
  38989. - private final Model model;
  38990. -
  38991. - /** A list of input tensors. */
  38992. - private final List</* @Nullable */ Tensor> inputTensors;
  38993. -
  38994. - /** A list of output tensors. */
  38995. - private final List</* @Nullable */ Tensor> outputTensors;
  38996. -
  38997. - /** Identifier of the TFLite model metadata in the Metadata array. */
  38998. - static final String METADATA_FIELD_NAME = "TFLITE_METADATA";
  38999. -
  39000. - /**
  39001. - * Creates a {@link ModelInfo} with the model FlatBuffer, {@code buffer}.
  39002. - *
  39003. - * <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only supports
  39004. - * single subgraph so far. See the <a
  39005. - * href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
  39006. - * of how to specify subgraph during convertion for more information.</a> Therefore, all methods
  39007. - * in {@link ModelInfo} retrieves metadata of the first subgrpah as default.
  39008. - *
  39009. - * @param buffer the TFLite model FlatBuffer
  39010. - * @throws NullPointerException if {@code buffer} is null
  39011. - * @throws IllegalArgumentException if the model does not contain any subgraph, or the model does
  39012. - * not contain the expected identifier
  39013. - */
  39014. - ModelInfo(ByteBuffer buffer) {
  39015. - assertTFLiteModel(buffer);
  39016. -
  39017. - model = Model.getRootAsModel(buffer);
  39018. - checkArgument(model.subgraphsLength() > 0, "The model does not contain any subgraph.");
  39019. -
  39020. - inputTensors = getInputTensors(model);
  39021. - outputTensors = getOutputTensors(model);
  39022. - }
  39023. -
  39024. - /**
  39025. - * Gets the input tensor with {@code inputIndex}.
  39026. - *
  39027. - * @param inputIndex The index of the desired input tensor.
  39028. - * @throws IllegalArgumentException if the inputIndex specified is invalid.
  39029. - */
  39030. - @Nullable
  39031. - Tensor getInputTensor(int inputIndex) {
  39032. - checkArgument(
  39033. - inputIndex >= 0 && inputIndex < inputTensors.size(),
  39034. - "The inputIndex specified is invalid.");
  39035. - return inputTensors.get(inputIndex);
  39036. - }
  39037. -
  39038. - int getInputTensorCount() {
  39039. - return inputTensors.size();
  39040. - }
  39041. -
  39042. - /**
  39043. - * Gets shape of the input tensor with {@code inputIndex}.
  39044. - *
  39045. - * @param inputIndex The index of the desired intput tensor.
  39046. - */
  39047. - int[] getInputTensorShape(int inputIndex) {
  39048. - Tensor tensor = getInputTensor(inputIndex);
  39049. - return getShape(tensor);
  39050. - }
  39051. -
  39052. - /**
  39053. - * Gets the {@link TensorType} in byte of the input tensor with {@code inputIndex}.
  39054. - *
  39055. - * @param inputIndex The index of the desired intput tensor.
  39056. - */
  39057. - byte getInputTensorType(int inputIndex) {
  39058. - return getInputTensor(inputIndex).type();
  39059. - }
  39060. -
  39061. - /** Gets the metadata FlatBuffer from the model FlatBuffer. */
  39062. - @Nullable
  39063. - ByteBuffer getMetadataBuffer() {
  39064. - // Some models may not have metadata, and this is allowed.
  39065. - if (model.metadataLength() == 0) {
  39066. - return null;
  39067. + /** The model that is loaded from TFLite model FlatBuffer. */
  39068. + private final Model model;
  39069. +
  39070. + /** A list of input tensors. */
  39071. + private final List</* @Nullable */ Tensor> inputTensors;
  39072. +
  39073. + /** A list of output tensors. */
  39074. + private final List</* @Nullable */ Tensor> outputTensors;
  39075. +
  39076. + /** Identifier of the TFLite model metadata in the Metadata array. */
  39077. + static final String METADATA_FIELD_NAME = "TFLITE_METADATA";
  39078. +
  39079. + /**
  39080. + * Creates a {@link ModelInfo} with the model FlatBuffer, {@code buffer}.
  39081. + *
  39082. + * <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only
  39083. + * supports single subgraph so far. See the <a
  39084. + * href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
  39085. + * of how to specify subgraph during convertion for more information.</a> Therefore, all methods
  39086. + * in {@link ModelInfo} retrieves metadata of the first subgrpah as default.
  39087. + *
  39088. + * @param buffer the TFLite model FlatBuffer
  39089. + * @throws NullPointerException if {@code buffer} is null
  39090. + * @throws IllegalArgumentException if the model does not contain any subgraph, or the model
  39091. + * does
  39092. + * not contain the expected identifier
  39093. + */
  39094. + ModelInfo(ByteBuffer buffer) {
  39095. + assertTFLiteModel(buffer);
  39096. +
  39097. + model = Model.getRootAsModel(buffer);
  39098. + checkArgument(model.subgraphsLength() > 0, "The model does not contain any subgraph.");
  39099. +
  39100. + inputTensors = getInputTensors(model);
  39101. + outputTensors = getOutputTensors(model);
  39102. + }
  39103. +
  39104. + /**
  39105. + * Gets the input tensor with {@code inputIndex}.
  39106. + *
  39107. + * @param inputIndex The index of the desired input tensor.
  39108. + * @throws IllegalArgumentException if the inputIndex specified is invalid.
  39109. + */
  39110. + @Nullable
  39111. + Tensor getInputTensor(int inputIndex) {
  39112. + checkArgument(inputIndex >= 0 && inputIndex < inputTensors.size(),
  39113. + "The inputIndex specified is invalid.");
  39114. + return inputTensors.get(inputIndex);
  39115. + }
  39116. +
  39117. + int getInputTensorCount() {
  39118. + return inputTensors.size();
  39119. + }
  39120. +
  39121. + /**
  39122. + * Gets shape of the input tensor with {@code inputIndex}.
  39123. + *
  39124. + * @param inputIndex The index of the desired intput tensor.
  39125. + */
  39126. + int[] getInputTensorShape(int inputIndex) {
  39127. + Tensor tensor = getInputTensor(inputIndex);
  39128. + return getShape(tensor);
  39129. }
  39130. - for (int i = 0; i < model.metadataLength(); i++) {
  39131. - Metadata meta = model.metadata(i);
  39132. - if (METADATA_FIELD_NAME.equals(meta.name())) {
  39133. - long bufferIndex = meta.buffer();
  39134. - Buffer metadataBuf = model.buffers((int) bufferIndex);
  39135. - return metadataBuf.dataAsByteBuffer();
  39136. - }
  39137. + /**
  39138. + * Gets the {@link TensorType} in byte of the input tensor with {@code inputIndex}.
  39139. + *
  39140. + * @param inputIndex The index of the desired intput tensor.
  39141. + */
  39142. + byte getInputTensorType(int inputIndex) {
  39143. + return getInputTensor(inputIndex).type();
  39144. }
  39145. - return null;
  39146. - }
  39147. -
  39148. - /**
  39149. - * Gets the output tensor with {@code outputIndex}.
  39150. - *
  39151. - * @param outputIndex The index of the desired outtput tensor.
  39152. - * @throws IllegalArgumentException if the outputIndex specified is invalid.
  39153. - */
  39154. - @Nullable
  39155. - Tensor getOutputTensor(int outputIndex) {
  39156. - checkArgument(
  39157. - outputIndex >= 0 && outputIndex < outputTensors.size(),
  39158. - "The outputIndex specified is invalid.");
  39159. - return outputTensors.get(outputIndex);
  39160. - }
  39161. -
  39162. - int getOutputTensorCount() {
  39163. - return outputTensors.size();
  39164. - }
  39165. -
  39166. - /**
  39167. - * Gets shape of the output tensor with {@code outputIndex}.
  39168. - *
  39169. - * @param outputIndex The index of the desired outtput tensor.
  39170. - */
  39171. - int[] getOutputTensorShape(int outputIndex) {
  39172. - Tensor tensor = getOutputTensor(outputIndex);
  39173. - return getShape(tensor);
  39174. - }
  39175. -
  39176. - /**
  39177. - * Gets the {@link TensorType} in byte of the output tensor {@code outputIndex}.
  39178. - *
  39179. - * @param outputIndex The index of the desired outtput tensor.
  39180. - */
  39181. - byte getOutputTensorType(int outputIndex) {
  39182. - return getOutputTensor(outputIndex).type();
  39183. - }
  39184. -
  39185. - /**
  39186. - * Gets the quantization parameters of a tensor.
  39187. - *
  39188. - * <p>Only quantized tensors have valid {@code QuantizationParameters}. For tensor that are not
  39189. - * quantized, the values of scale and zero_point are both 0.
  39190. - *
  39191. - * @param tensor The tensor whoes quantization parameters is desired.
  39192. - * @throws NullPointerException if the tensor is null.
  39193. - * @throws IllegalArgumentException if {@code scale} and {@code zeroPoint} of the tensor's {@link
  39194. - * QuantizationParameters} are not single values.
  39195. - */
  39196. - QuantizationParams getQuantizationParams(Tensor tensor) {
  39197. - checkNotNull(tensor, "Tensor cannot be null.");
  39198. -
  39199. - float scale;
  39200. - int zeroPoint;
  39201. - QuantizationParameters quantization = tensor.quantization();
  39202. -
  39203. - // Tensors that are not quantized do not have quantization parameters, which can be null when
  39204. - // being extracted from the flatbuffer.
  39205. - if (quantization == null) {
  39206. - scale = 0.0f;
  39207. - zeroPoint = 0;
  39208. - return new QuantizationParams(scale, zeroPoint);
  39209. +
  39210. + /** Gets the metadata FlatBuffer from the model FlatBuffer. */
  39211. + @Nullable
  39212. + ByteBuffer getMetadataBuffer() {
  39213. + // Some models may not have metadata, and this is allowed.
  39214. + if (model.metadataLength() == 0) {
  39215. + return null;
  39216. + }
  39217. +
  39218. + for (int i = 0; i < model.metadataLength(); i++) {
  39219. + Metadata meta = model.metadata(i);
  39220. + if (METADATA_FIELD_NAME.equals(meta.name())) {
  39221. + long bufferIndex = meta.buffer();
  39222. + Buffer metadataBuf = model.buffers((int) bufferIndex);
  39223. + return metadataBuf.dataAsByteBuffer();
  39224. + }
  39225. + }
  39226. + return null;
  39227. + }
  39228. +
  39229. + /**
  39230. + * Gets the output tensor with {@code outputIndex}.
  39231. + *
  39232. + * @param outputIndex The index of the desired outtput tensor.
  39233. + * @throws IllegalArgumentException if the outputIndex specified is invalid.
  39234. + */
  39235. + @Nullable
  39236. + Tensor getOutputTensor(int outputIndex) {
  39237. + checkArgument(outputIndex >= 0 && outputIndex < outputTensors.size(),
  39238. + "The outputIndex specified is invalid.");
  39239. + return outputTensors.get(outputIndex);
  39240. + }
  39241. +
  39242. + int getOutputTensorCount() {
  39243. + return outputTensors.size();
  39244. + }
  39245. +
  39246. + /**
  39247. + * Gets shape of the output tensor with {@code outputIndex}.
  39248. + *
  39249. + * @param outputIndex The index of the desired outtput tensor.
  39250. + */
  39251. + int[] getOutputTensorShape(int outputIndex) {
  39252. + Tensor tensor = getOutputTensor(outputIndex);
  39253. + return getShape(tensor);
  39254. }
  39255. - // Tensors that are not quantized do not have quantization parameters.
  39256. - // quantization.scaleLength() and quantization.zeroPointLength() may both return 0.
  39257. - checkArgument(
  39258. - quantization.scaleLength() <= 1,
  39259. - "Input and output tensors do not support per-channel quantization.");
  39260. - checkArgument(
  39261. - quantization.zeroPointLength() <= 1,
  39262. - "Input and output tensors do not support per-channel quantization.");
  39263. -
  39264. - // For tensors that are not quantized, quantization.scale(0) and quantization.zeroPoint(0) will
  39265. - // both be the default value in flatbuffer, 0. This behavior is consistent with the TFlite C++
  39266. - // runtime.
  39267. - scale = quantization.scale(0);
  39268. - // zeroPoint is a long value in the schema, but an integer in the C++ runtime. Here we keep it
  39269. - // consistent with the C++ runtime.
  39270. - zeroPoint = (int) quantization.zeroPoint(0);
  39271. -
  39272. - return new QuantizationParams(scale, zeroPoint);
  39273. - }
  39274. -
  39275. - /**
  39276. - * Verifies if the buffer is a valid TFLite model.
  39277. - *
  39278. - * @param buffer the TFLite model flatbuffer
  39279. - * @throws NullPointerException if {@code buffer} is null.
  39280. - * @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier
  39281. - */
  39282. - private static void assertTFLiteModel(ByteBuffer buffer) {
  39283. - checkNotNull(buffer, "Model flatbuffer cannot be null.");
  39284. - checkArgument(
  39285. - Model.ModelBufferHasIdentifier(buffer),
  39286. - "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
  39287. - + " flatbuffer.");
  39288. - }
  39289. -
  39290. - /**
  39291. - * Gets the shape of a tensor.
  39292. - *
  39293. - * @param tensor The tensor whoes shape is desired.
  39294. - * @throws NullPointerException if the tensor is null.
  39295. - */
  39296. - private static int[] getShape(Tensor tensor) {
  39297. - checkNotNull(tensor, "Tensor cannot be null.");
  39298. - int shapeDim = tensor.shapeLength();
  39299. - int[] tensorShape = new int[shapeDim];
  39300. - for (int i = 0; i < shapeDim; i++) {
  39301. - tensorShape[i] = tensor.shape(i);
  39302. + /**
  39303. + * Gets the {@link TensorType} in byte of the output tensor {@code outputIndex}.
  39304. + *
  39305. + * @param outputIndex The index of the desired outtput tensor.
  39306. + */
  39307. + byte getOutputTensorType(int outputIndex) {
  39308. + return getOutputTensor(outputIndex).type();
  39309. }
  39310. - return tensorShape;
  39311. - }
  39312. -
  39313. - /** Gets input tensors from a model. */
  39314. - private static List<Tensor> getInputTensors(Model model) {
  39315. - // TFLite only support one subgraph currently.
  39316. - SubGraph subgraph = model.subgraphs(0);
  39317. - int tensorNum = subgraph.inputsLength();
  39318. - ArrayList<Tensor> inputTensors = new ArrayList<>(tensorNum);
  39319. - for (int i = 0; i < tensorNum; i++) {
  39320. - inputTensors.add(subgraph.tensors(subgraph.inputs(i)));
  39321. +
  39322. + /**
  39323. + * Gets the quantization parameters of a tensor.
  39324. + *
  39325. + * <p>Only quantized tensors have valid {@code QuantizationParameters}. For tensor that are not
  39326. + * quantized, the values of scale and zero_point are both 0.
  39327. + *
  39328. + * @param tensor The tensor whoes quantization parameters is desired.
  39329. + * @throws NullPointerException if the tensor is null.
  39330. + * @throws IllegalArgumentException if {@code scale} and {@code zeroPoint} of the tensor's
  39331. + * {@link
  39332. + * QuantizationParameters} are not single values.
  39333. + */
  39334. + QuantizationParams getQuantizationParams(Tensor tensor) {
  39335. + checkNotNull(tensor, "Tensor cannot be null.");
  39336. +
  39337. + float scale;
  39338. + int zeroPoint;
  39339. + QuantizationParameters quantization = tensor.quantization();
  39340. +
  39341. + // Tensors that are not quantized do not have quantization parameters, which can be null
  39342. + // when being extracted from the flatbuffer.
  39343. + if (quantization == null) {
  39344. + scale = 0.0f;
  39345. + zeroPoint = 0;
  39346. + return new QuantizationParams(scale, zeroPoint);
  39347. + }
  39348. +
  39349. + // Tensors that are not quantized do not have quantization parameters.
  39350. + // quantization.scaleLength() and quantization.zeroPointLength() may both return 0.
  39351. + checkArgument(quantization.scaleLength() <= 1,
  39352. + "Input and output tensors do not support per-channel quantization.");
  39353. + checkArgument(quantization.zeroPointLength() <= 1,
  39354. + "Input and output tensors do not support per-channel quantization.");
  39355. +
  39356. + // For tensors that are not quantized, quantization.scale(0) and quantization.zeroPoint(0)
  39357. + // will both be the default value in flatbuffer, 0. This behavior is consistent with the
  39358. + // TFlite C++ runtime.
  39359. + scale = quantization.scale(0);
  39360. + // zeroPoint is a long value in the schema, but an integer in the C++ runtime. Here we keep
  39361. + // it consistent with the C++ runtime.
  39362. + zeroPoint = (int) quantization.zeroPoint(0);
  39363. +
  39364. + return new QuantizationParams(scale, zeroPoint);
  39365. }
  39366. - return Collections.unmodifiableList(inputTensors);
  39367. - }
  39368. -
  39369. - /** Gets output tensors from a model. */
  39370. - private static List<Tensor> getOutputTensors(Model model) {
  39371. - // TFLite only support one subgraph currently.
  39372. - SubGraph subgraph = model.subgraphs(0);
  39373. - int tensorNum = subgraph.outputsLength();
  39374. - ArrayList<Tensor> outputTensors = new ArrayList<>(tensorNum);
  39375. - for (int i = 0; i < tensorNum; i++) {
  39376. - outputTensors.add(subgraph.tensors(subgraph.outputs(i)));
  39377. +
  39378. + /**
  39379. + * Verifies if the buffer is a valid TFLite model.
  39380. + *
  39381. + * @param buffer the TFLite model flatbuffer
  39382. + * @throws NullPointerException if {@code buffer} is null.
  39383. + * @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier
  39384. + */
  39385. + private static void assertTFLiteModel(ByteBuffer buffer) {
  39386. + checkNotNull(buffer, "Model flatbuffer cannot be null.");
  39387. + checkArgument(Model.ModelBufferHasIdentifier(buffer),
  39388. + "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
  39389. + + " flatbuffer.");
  39390. + }
  39391. +
  39392. + /**
  39393. + * Gets the shape of a tensor.
  39394. + *
  39395. + * @param tensor The tensor whoes shape is desired.
  39396. + * @throws NullPointerException if the tensor is null.
  39397. + */
  39398. + private static int[] getShape(Tensor tensor) {
  39399. + checkNotNull(tensor, "Tensor cannot be null.");
  39400. + int shapeDim = tensor.shapeLength();
  39401. + int[] tensorShape = new int[shapeDim];
  39402. + for (int i = 0; i < shapeDim; i++) {
  39403. + tensorShape[i] = tensor.shape(i);
  39404. + }
  39405. + return tensorShape;
  39406. + }
  39407. +
  39408. + /** Gets input tensors from a model. */
  39409. + private static List<Tensor> getInputTensors(Model model) {
  39410. + // TFLite only support one subgraph currently.
  39411. + SubGraph subgraph = model.subgraphs(0);
  39412. + int tensorNum = subgraph.inputsLength();
  39413. + ArrayList<Tensor> inputTensors = new ArrayList<>(tensorNum);
  39414. + for (int i = 0; i < tensorNum; i++) {
  39415. + inputTensors.add(subgraph.tensors(subgraph.inputs(i)));
  39416. + }
  39417. + return Collections.unmodifiableList(inputTensors);
  39418. + }
  39419. +
  39420. + /** Gets output tensors from a model. */
  39421. + private static List<Tensor> getOutputTensors(Model model) {
  39422. + // TFLite only support one subgraph currently.
  39423. + SubGraph subgraph = model.subgraphs(0);
  39424. + int tensorNum = subgraph.outputsLength();
  39425. + ArrayList<Tensor> outputTensors = new ArrayList<>(tensorNum);
  39426. + for (int i = 0; i < tensorNum; i++) {
  39427. + outputTensors.add(subgraph.tensors(subgraph.outputs(i)));
  39428. + }
  39429. + return Collections.unmodifiableList(outputTensors);
  39430. }
  39431. - return Collections.unmodifiableList(outputTensors);
  39432. - }
  39433. }
  39434. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
  39435. index 751ed500dc2fc..7ee01df094283 100644
  39436. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
  39437. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
  39438. @@ -18,136 +18,133 @@ package org.tensorflow.lite.support.metadata;
  39439. import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
  39440. import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
  39441. -import java.nio.ByteBuffer;
  39442. -import java.util.ArrayList;
  39443. -import java.util.Collections;
  39444. -import java.util.List;
  39445. import org.checkerframework.checker.nullness.qual.Nullable;
  39446. import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
  39447. import org.tensorflow.lite.support.metadata.schema.SubGraphMetadata;
  39448. import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
  39449. +import java.nio.ByteBuffer;
  39450. +import java.util.ArrayList;
  39451. +import java.util.Collections;
  39452. +import java.util.List;
  39453. +
  39454. /** Extracts model metadata information out of TFLite metadata FlatBuffer. */
  39455. final class ModelMetadataInfo {
  39456. - /** The root handler for the model metadata. */
  39457. - private final ModelMetadata modelMetadata;
  39458. -
  39459. - /** Metadata array of input tensors. */
  39460. - private final List</* @Nullable */ TensorMetadata> inputsMetadata;
  39461. -
  39462. - /** Metadata array of output tensors. */
  39463. - private final List</* @Nullable */ TensorMetadata> outputsMetadata;
  39464. -
  39465. - /** The minimum parser version required to fully understand the metadata flatbuffer. */
  39466. - private final String /* @Nullable */ minVersion;
  39467. -
  39468. - /**
  39469. - * Creates a {@link ModelMetadataInfo} with the metadata FlatBuffer, {@code buffer}.
  39470. - *
  39471. - * @param buffer the TFLite metadata FlatBuffer
  39472. - * @throws NullPointerException if {@code buffer} is null
  39473. - * @throws IllegalArgumentException if {@code buffer} does not contain any subgraph metadata, or
  39474. - * it does not contain the expected identifier
  39475. - */
  39476. - ModelMetadataInfo(ByteBuffer buffer) {
  39477. - assertTFLiteMetadata(buffer);
  39478. -
  39479. - modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer);
  39480. - checkArgument(
  39481. - modelMetadata.subgraphMetadataLength() > 0,
  39482. - "The metadata flatbuffer does not contain any subgraph metadata.");
  39483. -
  39484. - inputsMetadata = getInputsMetadata(modelMetadata);
  39485. - outputsMetadata = getOutputsMetadata(modelMetadata);
  39486. - minVersion = modelMetadata.minParserVersion();
  39487. - }
  39488. -
  39489. - /** Gets the count of input tensors with metadata in the metadata FlatBuffer. */
  39490. - int getInputTensorCount() {
  39491. - return inputsMetadata.size();
  39492. - }
  39493. -
  39494. - /**
  39495. - * Gets the metadata for the input tensor specified by {@code inputIndex}.
  39496. - *
  39497. - * @param inputIndex The index of the desired intput tensor.
  39498. - * @throws IllegalArgumentException if the inputIndex specified is invalid.
  39499. - */
  39500. - @Nullable
  39501. - TensorMetadata getInputTensorMetadata(int inputIndex) {
  39502. - checkArgument(
  39503. - inputIndex >= 0 && inputIndex < inputsMetadata.size(),
  39504. - "The inputIndex specified is invalid.");
  39505. - return inputsMetadata.get(inputIndex);
  39506. - }
  39507. -
  39508. - /**
  39509. - * Gets the minimum parser version of the metadata. It can be {@code null} if the version is not
  39510. - * populated.
  39511. - */
  39512. - @Nullable
  39513. - String getMininumParserVersion() {
  39514. - return minVersion;
  39515. - }
  39516. -
  39517. - /** Gets the root handler for the model metadata. */
  39518. - ModelMetadata getModelMetadata() {
  39519. - return modelMetadata;
  39520. - }
  39521. -
  39522. - /** Gets the count of output tensors with metadata in the metadata FlatBuffer. */
  39523. - int getOutputTensorCount() {
  39524. - return outputsMetadata.size();
  39525. - }
  39526. -
  39527. - /**
  39528. - * Gets the metadata for the output tensor specified by {@code outputIndex}.
  39529. - *
  39530. - * @param outputIndex The index of the desired output tensor.
  39531. - * @throws IllegalArgumentException if the outputIndex specified is invalid.
  39532. - */
  39533. - @Nullable
  39534. - TensorMetadata getOutputTensorMetadata(int outputIndex) {
  39535. - checkArgument(
  39536. - outputIndex >= 0 && outputIndex < outputsMetadata.size(),
  39537. - "The outputIndex specified is invalid.");
  39538. - return outputsMetadata.get(outputIndex);
  39539. - }
  39540. -
  39541. - /**
  39542. - * Verifies if the buffer is a valid TFLite metadata flatbuffer.
  39543. - *
  39544. - * @param buffer the TFLite metadata flatbuffer
  39545. - * @throws NullPointerException if {@code buffer} is null.
  39546. - * @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier
  39547. - */
  39548. - private static void assertTFLiteMetadata(ByteBuffer buffer) {
  39549. - checkNotNull(buffer, "Metadata flatbuffer cannot be null.");
  39550. - checkArgument(
  39551. - ModelMetadata.ModelMetadataBufferHasIdentifier(buffer),
  39552. - "The identifier of the metadata is invalid. The buffer may not be a valid TFLite metadata"
  39553. - + " flatbuffer.");
  39554. - }
  39555. -
  39556. - /** Gets metadata for all input tensors. */
  39557. - private static List<TensorMetadata> getInputsMetadata(ModelMetadata modelMetadata) {
  39558. - SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
  39559. - int tensorNum = subgraphMetadata.inputTensorMetadataLength();
  39560. - ArrayList<TensorMetadata> inputsMetadata = new ArrayList<>(tensorNum);
  39561. - for (int i = 0; i < tensorNum; i++) {
  39562. - inputsMetadata.add(subgraphMetadata.inputTensorMetadata(i));
  39563. + /** The root handler for the model metadata. */
  39564. + private final ModelMetadata modelMetadata;
  39565. +
  39566. + /** Metadata array of input tensors. */
  39567. + private final List</* @Nullable */ TensorMetadata> inputsMetadata;
  39568. +
  39569. + /** Metadata array of output tensors. */
  39570. + private final List</* @Nullable */ TensorMetadata> outputsMetadata;
  39571. +
  39572. + /** The minimum parser version required to fully understand the metadata flatbuffer. */
  39573. + private final String /* @Nullable */ minVersion;
  39574. +
  39575. + /**
  39576. + * Creates a {@link ModelMetadataInfo} with the metadata FlatBuffer, {@code buffer}.
  39577. + *
  39578. + * @param buffer the TFLite metadata FlatBuffer
  39579. + * @throws NullPointerException if {@code buffer} is null
  39580. + * @throws IllegalArgumentException if {@code buffer} does not contain any subgraph metadata, or
  39581. + * it does not contain the expected identifier
  39582. + */
  39583. + ModelMetadataInfo(ByteBuffer buffer) {
  39584. + assertTFLiteMetadata(buffer);
  39585. +
  39586. + modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer);
  39587. + checkArgument(modelMetadata.subgraphMetadataLength() > 0,
  39588. + "The metadata flatbuffer does not contain any subgraph metadata.");
  39589. +
  39590. + inputsMetadata = getInputsMetadata(modelMetadata);
  39591. + outputsMetadata = getOutputsMetadata(modelMetadata);
  39592. + minVersion = modelMetadata.minParserVersion();
  39593. + }
  39594. +
  39595. + /** Gets the count of input tensors with metadata in the metadata FlatBuffer. */
  39596. + int getInputTensorCount() {
  39597. + return inputsMetadata.size();
  39598. + }
  39599. +
  39600. + /**
  39601. + * Gets the metadata for the input tensor specified by {@code inputIndex}.
  39602. + *
  39603. + * @param inputIndex The index of the desired intput tensor.
  39604. + * @throws IllegalArgumentException if the inputIndex specified is invalid.
  39605. + */
  39606. + @Nullable
  39607. + TensorMetadata getInputTensorMetadata(int inputIndex) {
  39608. + checkArgument(inputIndex >= 0 && inputIndex < inputsMetadata.size(),
  39609. + "The inputIndex specified is invalid.");
  39610. + return inputsMetadata.get(inputIndex);
  39611. }
  39612. - return Collections.unmodifiableList(inputsMetadata);
  39613. - }
  39614. -
  39615. - /** Gets metadata for all output tensors. */
  39616. - private static List<TensorMetadata> getOutputsMetadata(ModelMetadata modelMetadata) {
  39617. - SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
  39618. - int tensorNum = subgraphMetadata.outputTensorMetadataLength();
  39619. - ArrayList<TensorMetadata> outputsMetadata = new ArrayList<>(tensorNum);
  39620. - for (int i = 0; i < tensorNum; i++) {
  39621. - outputsMetadata.add(subgraphMetadata.outputTensorMetadata(i));
  39622. +
  39623. + /**
  39624. + * Gets the minimum parser version of the metadata. It can be {@code null} if the version is not
  39625. + * populated.
  39626. + */
  39627. + @Nullable
  39628. + String getMininumParserVersion() {
  39629. + return minVersion;
  39630. + }
  39631. +
  39632. + /** Gets the root handler for the model metadata. */
  39633. + ModelMetadata getModelMetadata() {
  39634. + return modelMetadata;
  39635. + }
  39636. +
  39637. + /** Gets the count of output tensors with metadata in the metadata FlatBuffer. */
  39638. + int getOutputTensorCount() {
  39639. + return outputsMetadata.size();
  39640. + }
  39641. +
  39642. + /**
  39643. + * Gets the metadata for the output tensor specified by {@code outputIndex}.
  39644. + *
  39645. + * @param outputIndex The index of the desired output tensor.
  39646. + * @throws IllegalArgumentException if the outputIndex specified is invalid.
  39647. + */
  39648. + @Nullable
  39649. + TensorMetadata getOutputTensorMetadata(int outputIndex) {
  39650. + checkArgument(outputIndex >= 0 && outputIndex < outputsMetadata.size(),
  39651. + "The outputIndex specified is invalid.");
  39652. + return outputsMetadata.get(outputIndex);
  39653. + }
  39654. +
  39655. + /**
  39656. + * Verifies if the buffer is a valid TFLite metadata flatbuffer.
  39657. + *
  39658. + * @param buffer the TFLite metadata flatbuffer
  39659. + * @throws NullPointerException if {@code buffer} is null.
  39660. + * @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier
  39661. + */
  39662. + private static void assertTFLiteMetadata(ByteBuffer buffer) {
  39663. + checkNotNull(buffer, "Metadata flatbuffer cannot be null.");
  39664. + checkArgument(ModelMetadata.ModelMetadataBufferHasIdentifier(buffer),
  39665. + "The identifier of the metadata is invalid. The buffer may not be a valid TFLite metadata"
  39666. + + " flatbuffer.");
  39667. + }
  39668. +
  39669. + /** Gets metadata for all input tensors. */
  39670. + private static List<TensorMetadata> getInputsMetadata(ModelMetadata modelMetadata) {
  39671. + SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
  39672. + int tensorNum = subgraphMetadata.inputTensorMetadataLength();
  39673. + ArrayList<TensorMetadata> inputsMetadata = new ArrayList<>(tensorNum);
  39674. + for (int i = 0; i < tensorNum; i++) {
  39675. + inputsMetadata.add(subgraphMetadata.inputTensorMetadata(i));
  39676. + }
  39677. + return Collections.unmodifiableList(inputsMetadata);
  39678. + }
  39679. +
  39680. + /** Gets metadata for all output tensors. */
  39681. + private static List<TensorMetadata> getOutputsMetadata(ModelMetadata modelMetadata) {
  39682. + SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
  39683. + int tensorNum = subgraphMetadata.outputTensorMetadataLength();
  39684. + ArrayList<TensorMetadata> outputsMetadata = new ArrayList<>(tensorNum);
  39685. + for (int i = 0; i < tensorNum; i++) {
  39686. + outputsMetadata.add(subgraphMetadata.outputTensorMetadata(i));
  39687. + }
  39688. + return Collections.unmodifiableList(outputsMetadata);
  39689. }
  39690. - return Collections.unmodifiableList(outputsMetadata);
  39691. - }
  39692. }
  39693. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/Preconditions.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/Preconditions.java
  39694. index c2f20fbaacd76..ca3eed3490644 100644
  39695. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/Preconditions.java
  39696. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/Preconditions.java
  39697. @@ -19,166 +19,170 @@ import org.checkerframework.checker.nullness.qual.Nullable;
  39698. /** Static error checking util methods. */
  39699. final class Preconditions {
  39700. - /**
  39701. - * Ensures that an object reference passed as a parameter to the calling method is not null.
  39702. - *
  39703. - * @param reference an object reference
  39704. - * @return the non-null reference that was validated
  39705. - * @throws NullPointerException if {@code reference} is null
  39706. - */
  39707. - public static <T extends Object> T checkNotNull(T reference) {
  39708. - if (reference == null) {
  39709. - throw new NullPointerException("The object reference is null.");
  39710. + /**
  39711. + * Ensures that an object reference passed as a parameter to the calling method is not null.
  39712. + *
  39713. + * @param reference an object reference
  39714. + * @return the non-null reference that was validated
  39715. + * @throws NullPointerException if {@code reference} is null
  39716. + */
  39717. + public static <T extends Object> T checkNotNull(T reference) {
  39718. + if (reference == null) {
  39719. + throw new NullPointerException("The object reference is null.");
  39720. + }
  39721. + return reference;
  39722. }
  39723. - return reference;
  39724. - }
  39725. -
  39726. - /**
  39727. - * Ensures that an object reference passed as a parameter to the calling method is not null.
  39728. - *
  39729. - * @param reference an object reference
  39730. - * @param errorMessage the exception message to use if the check fails; will be converted to a
  39731. - * string using {@link String#valueOf(Object)}
  39732. - * @return the non-null reference that was validated
  39733. - * @throws NullPointerException if {@code reference} is null
  39734. - */
  39735. - public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
  39736. - if (reference == null) {
  39737. - throw new NullPointerException(String.valueOf(errorMessage));
  39738. +
  39739. + /**
  39740. + * Ensures that an object reference passed as a parameter to the calling method is not null.
  39741. + *
  39742. + * @param reference an object reference
  39743. + * @param errorMessage the exception message to use if the check fails; will be converted to a
  39744. + * string using {@link String#valueOf(Object)}
  39745. + * @return the non-null reference that was validated
  39746. + * @throws NullPointerException if {@code reference} is null
  39747. + */
  39748. + public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
  39749. + if (reference == null) {
  39750. + throw new NullPointerException(String.valueOf(errorMessage));
  39751. + }
  39752. + return reference;
  39753. + }
  39754. +
  39755. + /**
  39756. + * Ensures that the given String is not empty and not null.
  39757. + *
  39758. + * @param string the String to test
  39759. + * @return the non-null non-empty String that was validated
  39760. + * @throws IllegalArgumentException if {@code string} is null or empty
  39761. + */
  39762. + public static String checkNotEmpty(String string) {
  39763. + if (string == null || string.length() == 0) {
  39764. + throw new IllegalArgumentException("Given String is empty or null.");
  39765. + }
  39766. + return string;
  39767. }
  39768. - return reference;
  39769. - }
  39770. -
  39771. - /**
  39772. - * Ensures that the given String is not empty and not null.
  39773. - *
  39774. - * @param string the String to test
  39775. - * @return the non-null non-empty String that was validated
  39776. - * @throws IllegalArgumentException if {@code string} is null or empty
  39777. - */
  39778. - public static String checkNotEmpty(String string) {
  39779. - if (string == null || string.length() == 0) {
  39780. - throw new IllegalArgumentException("Given String is empty or null.");
  39781. +
  39782. + /**
  39783. + * Ensures that the given String is not empty and not null.
  39784. + *
  39785. + * @param string the String to test
  39786. + * @param errorMessage the exception message to use if the check fails; will be converted to a
  39787. + * string using {@link String#valueOf(Object)}
  39788. + * @return the non-null non-empty String that was validated
  39789. + * @throws IllegalArgumentException if {@code string} is null or empty
  39790. + */
  39791. + public static String checkNotEmpty(String string, Object errorMessage) {
  39792. + if (string == null || string.length() == 0) {
  39793. + throw new IllegalArgumentException(String.valueOf(errorMessage));
  39794. + }
  39795. + return string;
  39796. }
  39797. - return string;
  39798. - }
  39799. -
  39800. - /**
  39801. - * Ensures that the given String is not empty and not null.
  39802. - *
  39803. - * @param string the String to test
  39804. - * @param errorMessage the exception message to use if the check fails; will be converted to a
  39805. - * string using {@link String#valueOf(Object)}
  39806. - * @return the non-null non-empty String that was validated
  39807. - * @throws IllegalArgumentException if {@code string} is null or empty
  39808. - */
  39809. - public static String checkNotEmpty(String string, Object errorMessage) {
  39810. - if (string == null || string.length() == 0) {
  39811. - throw new IllegalArgumentException(String.valueOf(errorMessage));
  39812. +
  39813. + /**
  39814. + * Ensures the truth of an expression involving one or more parameters to the calling method.
  39815. + *
  39816. + * @param expression a boolean expression.
  39817. + * @throws IllegalArgumentException if {@code expression} is false.
  39818. + */
  39819. + public static void checkArgument(boolean expression) {
  39820. + if (!expression) {
  39821. + throw new IllegalArgumentException();
  39822. + }
  39823. }
  39824. - return string;
  39825. - }
  39826. -
  39827. - /**
  39828. - * Ensures the truth of an expression involving one or more parameters to the calling method.
  39829. - *
  39830. - * @param expression a boolean expression.
  39831. - * @throws IllegalArgumentException if {@code expression} is false.
  39832. - */
  39833. - public static void checkArgument(boolean expression) {
  39834. - if (!expression) {
  39835. - throw new IllegalArgumentException();
  39836. +
  39837. + /**
  39838. + * Ensures the truth of an expression involving one or more parameters to the calling method.
  39839. + *
  39840. + * @param expression a boolean expression.
  39841. + * @param errorMessage the exception message to use if the check fails; will be converted to a
  39842. + * string using {@link String#valueOf(Object)}.
  39843. + * @throws IllegalArgumentException if {@code expression} is false.
  39844. + */
  39845. + public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
  39846. + if (!expression) {
  39847. + throw new IllegalArgumentException(String.valueOf(errorMessage));
  39848. + }
  39849. }
  39850. - }
  39851. -
  39852. - /**
  39853. - * Ensures the truth of an expression involving one or more parameters to the calling method.
  39854. - *
  39855. - * @param expression a boolean expression.
  39856. - * @param errorMessage the exception message to use if the check fails; will be converted to a
  39857. - * string using {@link String#valueOf(Object)}.
  39858. - * @throws IllegalArgumentException if {@code expression} is false.
  39859. - */
  39860. - public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
  39861. - if (!expression) {
  39862. - throw new IllegalArgumentException(String.valueOf(errorMessage));
  39863. +
  39864. + /**
  39865. + * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of
  39866. + * size
  39867. + * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
  39868. + *
  39869. + * @param index a user-supplied index identifying an element of an array, list or string
  39870. + * @param size the size of that array, list or string
  39871. + * @return the value of {@code index}
  39872. + * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code
  39873. + * size}
  39874. + * @throws IllegalArgumentException if {@code size} is negative
  39875. + */
  39876. + public static int checkElementIndex(int index, int size) {
  39877. + return checkElementIndex(index, size, "index");
  39878. }
  39879. - }
  39880. -
  39881. - /**
  39882. - * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
  39883. - * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
  39884. - *
  39885. - * @param index a user-supplied index identifying an element of an array, list or string
  39886. - * @param size the size of that array, list or string
  39887. - * @return the value of {@code index}
  39888. - * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
  39889. - * @throws IllegalArgumentException if {@code size} is negative
  39890. - */
  39891. - public static int checkElementIndex(int index, int size) {
  39892. - return checkElementIndex(index, size, "index");
  39893. - }
  39894. -
  39895. - /**
  39896. - * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
  39897. - * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
  39898. - *
  39899. - * @param index a user-supplied index identifying an element of an array, list or string
  39900. - * @param size the size of that array, list or string
  39901. - * @param desc the text to use to describe this index in an error message
  39902. - * @return the value of {@code index}
  39903. - * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
  39904. - * @throws IllegalArgumentException if {@code size} is negative
  39905. - */
  39906. - public static int checkElementIndex(int index, int size, @Nullable String desc) {
  39907. - // Carefully optimized for execution by hotspot (explanatory comment above)
  39908. - if (index < 0 || index >= size) {
  39909. - throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
  39910. +
  39911. + /**
  39912. + * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of
  39913. + * size
  39914. + * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
  39915. + *
  39916. + * @param index a user-supplied index identifying an element of an array, list or string
  39917. + * @param size the size of that array, list or string
  39918. + * @param desc the text to use to describe this index in an error message
  39919. + * @return the value of {@code index}
  39920. + * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code
  39921. + * size}
  39922. + * @throws IllegalArgumentException if {@code size} is negative
  39923. + */
  39924. + public static int checkElementIndex(int index, int size, @Nullable String desc) {
  39925. + // Carefully optimized for execution by hotspot (explanatory comment above)
  39926. + if (index < 0 || index >= size) {
  39927. + throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
  39928. + }
  39929. + return index;
  39930. }
  39931. - return index;
  39932. - }
  39933. -
  39934. - /**
  39935. - * Ensures the truth of an expression involving the state of the calling instance, but not
  39936. - * involving any parameters to the calling method.
  39937. - *
  39938. - * @param expression a boolean expression
  39939. - * @throws IllegalStateException if {@code expression} is false
  39940. - * @see Verify#verify Verify.verify()
  39941. - */
  39942. - public static void checkState(boolean expression) {
  39943. - if (!expression) {
  39944. - throw new IllegalStateException();
  39945. +
  39946. + /**
  39947. + * Ensures the truth of an expression involving the state of the calling instance, but not
  39948. + * involving any parameters to the calling method.
  39949. + *
  39950. + * @param expression a boolean expression
  39951. + * @throws IllegalStateException if {@code expression} is false
  39952. + * @see Verify#verify Verify.verify()
  39953. + */
  39954. + public static void checkState(boolean expression) {
  39955. + if (!expression) {
  39956. + throw new IllegalStateException();
  39957. + }
  39958. }
  39959. - }
  39960. -
  39961. - /**
  39962. - * Ensures the truth of an expression involving the state of the calling instance, but not
  39963. - * involving any parameters to the calling method.
  39964. - *
  39965. - * @param expression a boolean expression
  39966. - * @param errorMessage the exception message to use if the check fails; will be converted to a
  39967. - * string using {@link String#valueOf(Object)}
  39968. - * @throws IllegalStateException if {@code expression} is false
  39969. - * @see Verify#verify Verify.verify()
  39970. - */
  39971. - public static void checkState(boolean expression, @Nullable Object errorMessage) {
  39972. - if (!expression) {
  39973. - throw new IllegalStateException(String.valueOf(errorMessage));
  39974. +
  39975. + /**
  39976. + * Ensures the truth of an expression involving the state of the calling instance, but not
  39977. + * involving any parameters to the calling method.
  39978. + *
  39979. + * @param expression a boolean expression
  39980. + * @param errorMessage the exception message to use if the check fails; will be converted to a
  39981. + * string using {@link String#valueOf(Object)}
  39982. + * @throws IllegalStateException if {@code expression} is false
  39983. + * @see Verify#verify Verify.verify()
  39984. + */
  39985. + public static void checkState(boolean expression, @Nullable Object errorMessage) {
  39986. + if (!expression) {
  39987. + throw new IllegalStateException(String.valueOf(errorMessage));
  39988. + }
  39989. }
  39990. - }
  39991. -
  39992. - private static String badElementIndex(int index, int size, @Nullable String desc) {
  39993. - if (index < 0) {
  39994. - return String.format("%s (%s) must not be negative", desc, index);
  39995. - } else if (size < 0) {
  39996. - throw new IllegalArgumentException("negative size: " + size);
  39997. - } else { // index >= size
  39998. - return String.format("%s (%s) must be less than size (%s)", desc, index, size);
  39999. +
  40000. + private static String badElementIndex(int index, int size, @Nullable String desc) {
  40001. + if (index < 0) {
  40002. + return String.format("%s (%s) must not be negative", desc, index);
  40003. + } else if (size < 0) {
  40004. + throw new IllegalArgumentException("negative size: " + size);
  40005. + } else { // index >= size
  40006. + return String.format("%s (%s) must be less than size (%s)", desc, index, size);
  40007. + }
  40008. }
  40009. - }
  40010. - private Preconditions() {
  40011. - throw new AssertionError("Preconditions is Uninstantiable.");
  40012. - }
  40013. + private Preconditions() {
  40014. + throw new AssertionError("Preconditions is Uninstantiable.");
  40015. + }
  40016. }
  40017. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java
  40018. index c655786755baa..1408a3a73d86b 100644
  40019. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java
  40020. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java
  40021. @@ -29,79 +29,79 @@ import java.nio.channels.Channel;
  40022. * the MetadtaExtractor library consistent with the common used Java libraries.
  40023. */
  40024. interface SeekableByteChannelCompat extends Channel {
  40025. - /**
  40026. - * Reads a sequence of bytes from this channel into the given buffer.
  40027. - *
  40028. - * @param dst The buffer into which bytes are to be transferred
  40029. - * @return The number of bytes read, possibly zero, or <tt>-1</tt> if the channel has reached
  40030. - * end-of-stream
  40031. - * @throws NonReadableChannelException If this channel was not opened for reading
  40032. - * @throws ClosedChannelException If this channel is closed
  40033. - * @throws AsynchronousCloseException If another thread closes this channel while the read
  40034. - * operation is in progress
  40035. - * @throws ClosedByInterruptException If another thread interrupts the current thread while the
  40036. - * read operation is in progress, thereby closing the channel and setting the current thread's
  40037. - * interrupt status
  40038. - * @throws IOException If some other I/O error occurs
  40039. - */
  40040. - int read(ByteBuffer dst) throws IOException;
  40041. + /**
  40042. + * Reads a sequence of bytes from this channel into the given buffer.
  40043. + *
  40044. + * @param dst The buffer into which bytes are to be transferred
  40045. + * @return The number of bytes read, possibly zero, or <tt>-1</tt> if the channel has reached
  40046. + * end-of-stream
  40047. + * @throws NonReadableChannelException If this channel was not opened for reading
  40048. + * @throws ClosedChannelException If this channel is closed
  40049. + * @throws AsynchronousCloseException If another thread closes this channel while the read
  40050. + * operation is in progress
  40051. + * @throws ClosedByInterruptException If another thread interrupts the current thread while the
  40052. + * read operation is in progress, thereby closing the channel and setting the current
  40053. + * thread's interrupt status
  40054. + * @throws IOException If some other I/O error occurs
  40055. + */
  40056. + int read(ByteBuffer dst) throws IOException;
  40057. - /**
  40058. - * Writes a sequence of bytes to this channel from the given buffer.
  40059. - *
  40060. - * @param src The buffer from which bytes are to be retrieved
  40061. - * @return The number of bytes written, possibly zero
  40062. - * @throws NonWritableChannelException If this channel was not opened for writing
  40063. - * @throws ClosedChannelException If this channel is closed
  40064. - * @throws AsynchronousCloseException If another thread closes this channel while the write
  40065. - * operation is in progress
  40066. - * @throws ClosedByInterruptException If another thread interrupts the current thread while the
  40067. - * write operation is in progress, thereby closing the channel and setting the current
  40068. - * thread's interrupt status
  40069. - * @throws IOException If some other I/O error occurs
  40070. - */
  40071. - int write(ByteBuffer src) throws IOException;
  40072. + /**
  40073. + * Writes a sequence of bytes to this channel from the given buffer.
  40074. + *
  40075. + * @param src The buffer from which bytes are to be retrieved
  40076. + * @return The number of bytes written, possibly zero
  40077. + * @throws NonWritableChannelException If this channel was not opened for writing
  40078. + * @throws ClosedChannelException If this channel is closed
  40079. + * @throws AsynchronousCloseException If another thread closes this channel while the write
  40080. + * operation is in progress
  40081. + * @throws ClosedByInterruptException If another thread interrupts the current thread while the
  40082. + * write operation is in progress, thereby closing the channel and setting the current
  40083. + * thread's interrupt status
  40084. + * @throws IOException If some other I/O error occurs
  40085. + */
  40086. + int write(ByteBuffer src) throws IOException;
  40087. - /**
  40088. - * Returns this channel's position.
  40089. - *
  40090. - * @return This channel's position, a non-negative integer counting the number of bytes from the
  40091. - * beginning of the entity to the current position
  40092. - * @throws ClosedChannelException If this channel is closed
  40093. - * @throws IOException If some other I/O error occurs
  40094. - */
  40095. - long position() throws IOException;
  40096. + /**
  40097. + * Returns this channel's position.
  40098. + *
  40099. + * @return This channel's position, a non-negative integer counting the number of bytes from the
  40100. + * beginning of the entity to the current position
  40101. + * @throws ClosedChannelException If this channel is closed
  40102. + * @throws IOException If some other I/O error occurs
  40103. + */
  40104. + long position() throws IOException;
  40105. - /**
  40106. - * Sets this channel's position.
  40107. - *
  40108. - * @param newPosition The new position, a non-negative integer counting the number of bytes from
  40109. - * the beginning of the entity
  40110. - * @return This channel
  40111. - * @throws ClosedChannelException If this channel is closed
  40112. - * @throws IllegalArgumentException If the new position is negative
  40113. - * @throws IOException If some other I/O error occurs
  40114. - */
  40115. - SeekableByteChannelCompat position(long newPosition) throws IOException;
  40116. + /**
  40117. + * Sets this channel's position.
  40118. + *
  40119. + * @param newPosition The new position, a non-negative integer counting the number of bytes from
  40120. + * the beginning of the entity
  40121. + * @return This channel
  40122. + * @throws ClosedChannelException If this channel is closed
  40123. + * @throws IllegalArgumentException If the new position is negative
  40124. + * @throws IOException If some other I/O error occurs
  40125. + */
  40126. + SeekableByteChannelCompat position(long newPosition) throws IOException;
  40127. - /**
  40128. - * Returns the current size of entity to which this channel is connected.
  40129. - *
  40130. - * @return The current size, measured in bytes
  40131. - * @throws ClosedChannelException If this channel is closed
  40132. - * @throws IOException If some other I/O error occurs
  40133. - */
  40134. - long size() throws IOException;
  40135. + /**
  40136. + * Returns the current size of entity to which this channel is connected.
  40137. + *
  40138. + * @return The current size, measured in bytes
  40139. + * @throws ClosedChannelException If this channel is closed
  40140. + * @throws IOException If some other I/O error occurs
  40141. + */
  40142. + long size() throws IOException;
  40143. - /**
  40144. - * Truncates the entity, to which this channel is connected, to the given size.
  40145. - *
  40146. - * @param size The new size, a non-negative byte count
  40147. - * @return This channel
  40148. - * @throws NonWritableChannelException If this channel was not opened for writing
  40149. - * @throws ClosedChannelException If this channel is closed
  40150. - * @throws IllegalArgumentException If the new size is negative
  40151. - * @throws IOException If some other I/O error occurs
  40152. - */
  40153. - SeekableByteChannelCompat truncate(long size) throws IOException;
  40154. + /**
  40155. + * Truncates the entity, to which this channel is connected, to the given size.
  40156. + *
  40157. + * @param size The new size, a non-negative byte count
  40158. + * @return This channel
  40159. + * @throws NonWritableChannelException If this channel was not opened for writing
  40160. + * @throws ClosedChannelException If this channel is closed
  40161. + * @throws IllegalArgumentException If the new size is negative
  40162. + * @throws IOException If some other I/O error occurs
  40163. + */
  40164. + SeekableByteChannelCompat truncate(long size) throws IOException;
  40165. }
  40166. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java
  40167. index 6b43e724fd814..c8a3fb806d920 100644
  40168. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java
  40169. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java
  40170. @@ -45,393 +45,389 @@ import java.util.zip.ZipException;
  40171. * size limit for Zip64, which is 4GB.
  40172. */
  40173. final class ZipFile implements Closeable {
  40174. - /** Maps String to list of ZipEntrys, name -> actual entries. */
  40175. - private final Map<String, List<ZipEntry>> nameMap;
  40176. -
  40177. - /** The actual data source. */
  40178. - private final ByteBufferChannel archive;
  40179. -
  40180. - /**
  40181. - * Opens the given {@link ByteBufferChannel} for reading, assuming "UTF8" for file names. {@link
  40182. - * ZipFile} does not synchronized over the buffer that is passed into it.
  40183. - *
  40184. - * @param channel the archive
  40185. - * @throws IOException if an error occurs while creating this {@link ZipFile}
  40186. - * @throws ZipException if the channel is not a zip archive
  40187. - * @throws NullPointerException if the archive is null
  40188. - */
  40189. - public static ZipFile createFrom(ByteBufferChannel channel) throws IOException {
  40190. - checkNotNull(channel);
  40191. - ZipParser zipParser = new ZipParser(channel);
  40192. - Map<String, List<ZipEntry>> nameMap = zipParser.parseEntries();
  40193. - return new ZipFile(channel, nameMap);
  40194. - }
  40195. -
  40196. - @Override
  40197. - public void close() {
  40198. - archive.close();
  40199. - }
  40200. -
  40201. - /**
  40202. - * Exposes the raw stream of the archive entry.
  40203. - *
  40204. - * <p>Since the associated files will not be compressed when being packed to the zip file, the raw
  40205. - * stream represents the non-compressed files.
  40206. - *
  40207. - * <p><b>WARNING:</b> The returned {@link InputStream}, is <b>not</b> thread-safe. If multiple
  40208. - * threads concurrently reading from the returned {@link InputStream}, it must be synchronized
  40209. - * externally.
  40210. - *
  40211. - * @param name name of the entry to get the stream for
  40212. - * @return the raw input stream containing data
  40213. - * @throws IllegalArgumentException if the specified file does not exist in the zip file
  40214. - */
  40215. - public InputStream getRawInputStream(String name) {
  40216. - checkArgument(
  40217. - nameMap.containsKey(name),
  40218. - String.format("The file, %s, does not exist in the zip file.", name));
  40219. -
  40220. - List<ZipEntry> entriesWithTheSameName = nameMap.get(name);
  40221. - ZipEntry entry = entriesWithTheSameName.get(0);
  40222. - long start = entry.getDataOffset();
  40223. - long remaining = entry.getSize();
  40224. - return new BoundedInputStream(archive, start, remaining);
  40225. - }
  40226. -
  40227. - /**
  40228. - * Exposes the file names of the included files.
  40229. - *
  40230. - * @return the file names of the included files
  40231. - */
  40232. - public Set<String> getFileNames() {
  40233. - return nameMap.keySet();
  40234. - }
  40235. -
  40236. - private ZipFile(ByteBufferChannel channel, Map<String, List<ZipEntry>> nameMap) {
  40237. - archive = channel;
  40238. - this.nameMap = nameMap;
  40239. - }
  40240. -
  40241. - /* Parses a Zip archive and gets the information for each {@link ZipEntry}. */
  40242. - private static class ZipParser {
  40243. - private final ByteBufferChannel archive;
  40244. -
  40245. - // Cached buffers that will only be used locally in the class to reduce garbage collection.
  40246. - private final ByteBuffer longBuffer =
  40247. - ByteBuffer.allocate(ZipConstants.LONG_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
  40248. - private final ByteBuffer intBuffer =
  40249. - ByteBuffer.allocate(ZipConstants.INT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
  40250. - private final ByteBuffer shortBuffer =
  40251. - ByteBuffer.allocate(ZipConstants.SHORT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
  40252. + /** Maps String to list of ZipEntrys, name -> actual entries. */
  40253. + private final Map<String, List<ZipEntry>> nameMap;
  40254. - private ZipParser(ByteBufferChannel archive) {
  40255. - this.archive = archive;
  40256. - }
  40257. -
  40258. - /**
  40259. - * Parses the underlying {@code archive} and returns the information as a list of {@link
  40260. - * ZipEntry}.
  40261. - */
  40262. - private Map<String, List<ZipEntry>> parseEntries() throws IOException {
  40263. - List<ZipEntry> entries = parseCentralDirectory();
  40264. - return parseLocalFileHeaderData(entries);
  40265. - }
  40266. -
  40267. - /**
  40268. - * Checks if the current position contains a central file header signature, {@link
  40269. - * ZipConstants#CENSIG}.
  40270. - */
  40271. - private boolean foundCentralFileheaderSignature() {
  40272. - long signature = (long) getInt();
  40273. - return signature == ZipConstants.CENSIG;
  40274. - }
  40275. -
  40276. - /**
  40277. - * Gets the value as a Java int from two bytes starting at the current position of the archive.
  40278. - */
  40279. - private int getShort() {
  40280. - shortBuffer.rewind();
  40281. - archive.read(shortBuffer);
  40282. - shortBuffer.flip();
  40283. - return (int) shortBuffer.getShort();
  40284. - }
  40285. + /** The actual data source. */
  40286. + private final ByteBufferChannel archive;
  40287. /**
  40288. - * Gets the value as a Java long from four bytes starting at the current position of the
  40289. - * archive.
  40290. + * Opens the given {@link ByteBufferChannel} for reading, assuming "UTF8" for file names. {@link
  40291. + * ZipFile} does not synchronized over the buffer that is passed into it.
  40292. + *
  40293. + * @param channel the archive
  40294. + * @throws IOException if an error occurs while creating this {@link ZipFile}
  40295. + * @throws ZipException if the channel is not a zip archive
  40296. + * @throws NullPointerException if the archive is null
  40297. */
  40298. - private int getInt() {
  40299. - intBuffer.rewind();
  40300. - archive.read(intBuffer);
  40301. - intBuffer.flip();
  40302. - return intBuffer.getInt();
  40303. + public static ZipFile createFrom(ByteBufferChannel channel) throws IOException {
  40304. + checkNotNull(channel);
  40305. + ZipParser zipParser = new ZipParser(channel);
  40306. + Map<String, List<ZipEntry>> nameMap = zipParser.parseEntries();
  40307. + return new ZipFile(channel, nameMap);
  40308. }
  40309. - /**
  40310. - * Gets the value as a Java long from four bytes starting at the current position of the
  40311. - * archive.
  40312. - */
  40313. - private long getLong() {
  40314. - longBuffer.rewind();
  40315. - archive.read(longBuffer);
  40316. - longBuffer.flip();
  40317. - return longBuffer.getLong();
  40318. + @Override
  40319. + public void close() {
  40320. + archive.close();
  40321. }
  40322. /**
  40323. - * Positions the archive at the start of the central directory.
  40324. + * Exposes the raw stream of the archive entry.
  40325. + *
  40326. + * <p>Since the associated files will not be compressed when being packed to the zip file, the
  40327. + * raw stream represents the non-compressed files.
  40328. *
  40329. - * <p>First, it searches for the signature of the "end of central directory record", {@link
  40330. - * ZipConstants#ENDSIG}. Position the stream at the start of the "end of central directory
  40331. - * record". The zip file are created without archive comments, thus {@link ZipConstants#ENDSIG}
  40332. - * should appear exactly at {@link ZipConstants#ENDHDR} from the end of the zip file.
  40333. + * <p><b>WARNING:</b> The returned {@link InputStream}, is <b>not</b> thread-safe. If multiple
  40334. + * threads concurrently reading from the returned {@link InputStream}, it must be synchronized
  40335. + * externally.
  40336. *
  40337. - * <p>Then, parse the "end of central dir record" and position the archive at the start of the
  40338. - * central directory.
  40339. + * @param name name of the entry to get the stream for
  40340. + * @return the raw input stream containing data
  40341. + * @throws IllegalArgumentException if the specified file does not exist in the zip file
  40342. */
  40343. - private void locateCentralDirectory() throws IOException {
  40344. - if (archive.size() < ZipConstants.ENDHDR) {
  40345. - throw new ZipException("The archive is not a ZIP archive.");
  40346. - }
  40347. -
  40348. - // Positions the archive at the start of the "end of central directory record".
  40349. - long offsetRecord = archive.size() - ZipConstants.ENDHDR;
  40350. - archive.position(offsetRecord);
  40351. -
  40352. - // Checks for the signature, {@link ZipConstants#ENDSIG}.
  40353. - long endSig = getLong();
  40354. - if (endSig != ZipConstants.ENDSIG) {
  40355. - throw new ZipException("The archive is not a ZIP archive.");
  40356. - }
  40357. -
  40358. - // Positions the archive at the “offset of central directory”.
  40359. - skipBytes(ZipConstants.ENDOFF - ZipConstants.ENDSUB);
  40360. - // Gets the offset to central directory
  40361. - long offsetDirectory = getInt();
  40362. - // Goes to the central directory.
  40363. - archive.position(offsetDirectory);
  40364. + public InputStream getRawInputStream(String name) {
  40365. + checkArgument(nameMap.containsKey(name),
  40366. + String.format("The file, %s, does not exist in the zip file.", name));
  40367. +
  40368. + List<ZipEntry> entriesWithTheSameName = nameMap.get(name);
  40369. + ZipEntry entry = entriesWithTheSameName.get(0);
  40370. + long start = entry.getDataOffset();
  40371. + long remaining = entry.getSize();
  40372. + return new BoundedInputStream(archive, start, remaining);
  40373. }
  40374. /**
  40375. - * Reads the central directory of the given archive and populates the internal tables with
  40376. - * {@link ZipEntry} instances.
  40377. + * Exposes the file names of the included files.
  40378. + *
  40379. + * @return the file names of the included files
  40380. */
  40381. - private List<ZipEntry> parseCentralDirectory() throws IOException {
  40382. - /** List of entries in the order they appear inside the central directory. */
  40383. - List<ZipEntry> entries = new ArrayList<>();
  40384. - locateCentralDirectory();
  40385. -
  40386. - while (foundCentralFileheaderSignature()) {
  40387. - ZipEntry entry = parseCentralDirectoryEntry();
  40388. - entries.add(entry);
  40389. - }
  40390. -
  40391. - return entries;
  40392. + public Set<String> getFileNames() {
  40393. + return nameMap.keySet();
  40394. }
  40395. - /**
  40396. - * Reads an individual entry of the central directory, creats an ZipEntry from it and adds it to
  40397. - * the global maps.
  40398. - */
  40399. - private ZipEntry parseCentralDirectoryEntry() throws IOException {
  40400. - // Positions the archive at the "compressed size" and read the value.
  40401. - skipBytes(ZipConstants.CENSIZ - ZipConstants.CENVEM);
  40402. - long compressSize = getInt();
  40403. -
  40404. - // Positions the archive at the "filename length" and read the value.
  40405. - skipBytes(ZipConstants.CENNAM - ZipConstants.CENLEN);
  40406. - int fileNameLen = getShort();
  40407. -
  40408. - // Reads the extra field length and the comment length.
  40409. - int extraLen = getShort();
  40410. - int commentLen = getShort();
  40411. -
  40412. - // Positions the archive at the "local file header offset" and read the value.
  40413. - skipBytes(ZipConstants.CENOFF - ZipConstants.CENDSK);
  40414. - long localHeaderOffset = getInt();
  40415. -
  40416. - // Reads the file name.
  40417. - byte[] fileNameBuf = new byte[fileNameLen];
  40418. - archive.read(ByteBuffer.wrap(fileNameBuf));
  40419. - String fileName = new String(fileNameBuf, Charset.forName("UTF-8"));
  40420. + private ZipFile(ByteBufferChannel channel, Map<String, List<ZipEntry>> nameMap) {
  40421. + archive = channel;
  40422. + this.nameMap = nameMap;
  40423. + }
  40424. - // Skips the extra field and the comment.
  40425. - skipBytes(extraLen + commentLen);
  40426. + /* Parses a Zip archive and gets the information for each {@link ZipEntry}. */
  40427. + private static class ZipParser {
  40428. + private final ByteBufferChannel archive;
  40429. - ZipEntry entry = new ZipEntry();
  40430. - entry.setSize(compressSize);
  40431. - entry.setLocalHeaderOffset(localHeaderOffset);
  40432. - entry.setName(fileName);
  40433. + // Cached buffers that will only be used locally in the class to reduce garbage collection.
  40434. + private final ByteBuffer longBuffer =
  40435. + ByteBuffer.allocate(ZipConstants.LONG_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
  40436. + private final ByteBuffer intBuffer =
  40437. + ByteBuffer.allocate(ZipConstants.INT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
  40438. + private final ByteBuffer shortBuffer =
  40439. + ByteBuffer.allocate(ZipConstants.SHORT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
  40440. - return entry;
  40441. - }
  40442. + private ZipParser(ByteBufferChannel archive) {
  40443. + this.archive = archive;
  40444. + }
  40445. - /** Walks through all recorded entries and records the offsets for the entry data. */
  40446. - private Map<String, List<ZipEntry>> parseLocalFileHeaderData(List<ZipEntry> entries) {
  40447. - /** Maps String to list of ZipEntrys, name -> actual entries. */
  40448. - Map<String, List<ZipEntry>> nameMap = new LinkedHashMap<>();
  40449. -
  40450. - for (ZipEntry entry : entries) {
  40451. - long offset = entry.getLocalHeaderOffset();
  40452. - archive.position(offset + ZipConstants.LOCNAM);
  40453. -
  40454. - // Gets the data offset of this entry.
  40455. - int fileNameLen = getShort();
  40456. - int extraFieldLen = getShort();
  40457. - long dataOffset =
  40458. - offset
  40459. - + ZipConstants.LOCEXT
  40460. - + ZipConstants.SHORT_BYTE_SIZE
  40461. - + fileNameLen
  40462. - + extraFieldLen;
  40463. - entry.setDataOffset(dataOffset);
  40464. -
  40465. - // Puts the entry into the nameMap.
  40466. - String name = entry.getName();
  40467. - List<ZipEntry> entriesWithTheSameName;
  40468. - if (nameMap.containsKey(name)) {
  40469. - entriesWithTheSameName = nameMap.get(name);
  40470. - } else {
  40471. - entriesWithTheSameName = new ArrayList<>();
  40472. - nameMap.put(name, entriesWithTheSameName);
  40473. + /**
  40474. + * Parses the underlying {@code archive} and returns the information as a list of {@link
  40475. + * ZipEntry}.
  40476. + */
  40477. + private Map<String, List<ZipEntry>> parseEntries() throws IOException {
  40478. + List<ZipEntry> entries = parseCentralDirectory();
  40479. + return parseLocalFileHeaderData(entries);
  40480. }
  40481. - entriesWithTheSameName.add(entry);
  40482. - }
  40483. - return nameMap;
  40484. - }
  40485. + /**
  40486. + * Checks if the current position contains a central file header signature, {@link
  40487. + * ZipConstants#CENSIG}.
  40488. + */
  40489. + private boolean foundCentralFileheaderSignature() {
  40490. + long signature = (long) getInt();
  40491. + return signature == ZipConstants.CENSIG;
  40492. + }
  40493. - /** Skips the given number of bytes or throws an EOFException if skipping failed. */
  40494. - private void skipBytes(int count) throws IOException {
  40495. - long currentPosition = archive.position();
  40496. - long newPosition = currentPosition + count;
  40497. - if (newPosition > archive.size()) {
  40498. - throw new EOFException();
  40499. - }
  40500. - archive.position(newPosition);
  40501. - }
  40502. - }
  40503. + /**
  40504. + * Gets the value as a Java int from two bytes starting at the current position of the
  40505. + * archive.
  40506. + */
  40507. + private int getShort() {
  40508. + shortBuffer.rewind();
  40509. + archive.read(shortBuffer);
  40510. + shortBuffer.flip();
  40511. + return (int) shortBuffer.getShort();
  40512. + }
  40513. - /** Stores the data offset and the size of an entry in the archive. */
  40514. - private static class ZipEntry {
  40515. + /**
  40516. + * Gets the value as a Java long from four bytes starting at the current position of the
  40517. + * archive.
  40518. + */
  40519. + private int getInt() {
  40520. + intBuffer.rewind();
  40521. + archive.read(intBuffer);
  40522. + intBuffer.flip();
  40523. + return intBuffer.getInt();
  40524. + }
  40525. - private String name;
  40526. - private long dataOffset = -1;
  40527. - private long size = -1;
  40528. - private long localHeaderOffset = -1;
  40529. + /**
  40530. + * Gets the value as a Java long from four bytes starting at the current position of the
  40531. + * archive.
  40532. + */
  40533. + private long getLong() {
  40534. + longBuffer.rewind();
  40535. + archive.read(longBuffer);
  40536. + longBuffer.flip();
  40537. + return longBuffer.getLong();
  40538. + }
  40539. - public long getSize() {
  40540. - return size;
  40541. - }
  40542. + /**
  40543. + * Positions the archive at the start of the central directory.
  40544. + *
  40545. + * <p>First, it searches for the signature of the "end of central directory record", {@link
  40546. + * ZipConstants#ENDSIG}. Position the stream at the start of the "end of central directory
  40547. + * record". The zip file are created without archive comments, thus {@link
  40548. + * ZipConstants#ENDSIG} should appear exactly at {@link ZipConstants#ENDHDR} from the end of
  40549. + * the zip file.
  40550. + *
  40551. + * <p>Then, parse the "end of central dir record" and position the archive at the start of
  40552. + * the central directory.
  40553. + */
  40554. + private void locateCentralDirectory() throws IOException {
  40555. + if (archive.size() < ZipConstants.ENDHDR) {
  40556. + throw new ZipException("The archive is not a ZIP archive.");
  40557. + }
  40558. +
  40559. + // Positions the archive at the start of the "end of central directory record".
  40560. + long offsetRecord = archive.size() - ZipConstants.ENDHDR;
  40561. + archive.position(offsetRecord);
  40562. +
  40563. + // Checks for the signature, {@link ZipConstants#ENDSIG}.
  40564. + long endSig = getLong();
  40565. + if (endSig != ZipConstants.ENDSIG) {
  40566. + throw new ZipException("The archive is not a ZIP archive.");
  40567. + }
  40568. +
  40569. + // Positions the archive at the “offset of central directory”.
  40570. + skipBytes(ZipConstants.ENDOFF - ZipConstants.ENDSUB);
  40571. + // Gets the offset to central directory
  40572. + long offsetDirectory = getInt();
  40573. + // Goes to the central directory.
  40574. + archive.position(offsetDirectory);
  40575. + }
  40576. - public long getDataOffset() {
  40577. - return dataOffset;
  40578. - }
  40579. + /**
  40580. + * Reads the central directory of the given archive and populates the internal tables with
  40581. + * {@link ZipEntry} instances.
  40582. + */
  40583. + private List<ZipEntry> parseCentralDirectory() throws IOException {
  40584. + /** List of entries in the order they appear inside the central directory. */
  40585. + List<ZipEntry> entries = new ArrayList<>();
  40586. + locateCentralDirectory();
  40587. +
  40588. + while (foundCentralFileheaderSignature()) {
  40589. + ZipEntry entry = parseCentralDirectoryEntry();
  40590. + entries.add(entry);
  40591. + }
  40592. +
  40593. + return entries;
  40594. + }
  40595. - public String getName() {
  40596. - return name;
  40597. - }
  40598. + /**
  40599. + * Reads an individual entry of the central directory, creats an ZipEntry from it and adds
  40600. + * it to the global maps.
  40601. + */
  40602. + private ZipEntry parseCentralDirectoryEntry() throws IOException {
  40603. + // Positions the archive at the "compressed size" and read the value.
  40604. + skipBytes(ZipConstants.CENSIZ - ZipConstants.CENVEM);
  40605. + long compressSize = getInt();
  40606. +
  40607. + // Positions the archive at the "filename length" and read the value.
  40608. + skipBytes(ZipConstants.CENNAM - ZipConstants.CENLEN);
  40609. + int fileNameLen = getShort();
  40610. +
  40611. + // Reads the extra field length and the comment length.
  40612. + int extraLen = getShort();
  40613. + int commentLen = getShort();
  40614. +
  40615. + // Positions the archive at the "local file header offset" and read the value.
  40616. + skipBytes(ZipConstants.CENOFF - ZipConstants.CENDSK);
  40617. + long localHeaderOffset = getInt();
  40618. +
  40619. + // Reads the file name.
  40620. + byte[] fileNameBuf = new byte[fileNameLen];
  40621. + archive.read(ByteBuffer.wrap(fileNameBuf));
  40622. + String fileName = new String(fileNameBuf, Charset.forName("UTF-8"));
  40623. +
  40624. + // Skips the extra field and the comment.
  40625. + skipBytes(extraLen + commentLen);
  40626. +
  40627. + ZipEntry entry = new ZipEntry();
  40628. + entry.setSize(compressSize);
  40629. + entry.setLocalHeaderOffset(localHeaderOffset);
  40630. + entry.setName(fileName);
  40631. +
  40632. + return entry;
  40633. + }
  40634. - public long getLocalHeaderOffset() {
  40635. - return localHeaderOffset;
  40636. - }
  40637. + /** Walks through all recorded entries and records the offsets for the entry data. */
  40638. + private Map<String, List<ZipEntry>> parseLocalFileHeaderData(List<ZipEntry> entries) {
  40639. + /** Maps String to list of ZipEntrys, name -> actual entries. */
  40640. + Map<String, List<ZipEntry>> nameMap = new LinkedHashMap<>();
  40641. +
  40642. + for (ZipEntry entry : entries) {
  40643. + long offset = entry.getLocalHeaderOffset();
  40644. + archive.position(offset + ZipConstants.LOCNAM);
  40645. +
  40646. + // Gets the data offset of this entry.
  40647. + int fileNameLen = getShort();
  40648. + int extraFieldLen = getShort();
  40649. + long dataOffset = offset + ZipConstants.LOCEXT + ZipConstants.SHORT_BYTE_SIZE
  40650. + + fileNameLen + extraFieldLen;
  40651. + entry.setDataOffset(dataOffset);
  40652. +
  40653. + // Puts the entry into the nameMap.
  40654. + String name = entry.getName();
  40655. + List<ZipEntry> entriesWithTheSameName;
  40656. + if (nameMap.containsKey(name)) {
  40657. + entriesWithTheSameName = nameMap.get(name);
  40658. + } else {
  40659. + entriesWithTheSameName = new ArrayList<>();
  40660. + nameMap.put(name, entriesWithTheSameName);
  40661. + }
  40662. + entriesWithTheSameName.add(entry);
  40663. + }
  40664. +
  40665. + return nameMap;
  40666. + }
  40667. - public void setSize(long size) {
  40668. - this.size = size;
  40669. + /** Skips the given number of bytes or throws an EOFException if skipping failed. */
  40670. + private void skipBytes(int count) throws IOException {
  40671. + long currentPosition = archive.position();
  40672. + long newPosition = currentPosition + count;
  40673. + if (newPosition > archive.size()) {
  40674. + throw new EOFException();
  40675. + }
  40676. + archive.position(newPosition);
  40677. + }
  40678. }
  40679. - public void setDataOffset(long dataOffset) {
  40680. - this.dataOffset = dataOffset;
  40681. - }
  40682. + /** Stores the data offset and the size of an entry in the archive. */
  40683. + private static class ZipEntry {
  40684. + private String name;
  40685. + private long dataOffset = -1;
  40686. + private long size = -1;
  40687. + private long localHeaderOffset = -1;
  40688. - public void setName(String name) {
  40689. - this.name = name;
  40690. - }
  40691. + public long getSize() {
  40692. + return size;
  40693. + }
  40694. - public void setLocalHeaderOffset(long localHeaderOffset) {
  40695. - this.localHeaderOffset = localHeaderOffset;
  40696. - }
  40697. - }
  40698. + public long getDataOffset() {
  40699. + return dataOffset;
  40700. + }
  40701. - /**
  40702. - * Various constants for this {@link ZipFile}.
  40703. - *
  40704. - * <p>Referenced from {@link java.util.zip.ZipConstants}.
  40705. - */
  40706. - private static class ZipConstants {
  40707. - /** length of Java short in bytes. */
  40708. - static final int SHORT_BYTE_SIZE = Short.SIZE / 8;
  40709. + public String getName() {
  40710. + return name;
  40711. + }
  40712. - /** length of Java int in bytes. */
  40713. - static final int INT_BYTE_SIZE = Integer.SIZE / 8;
  40714. + public long getLocalHeaderOffset() {
  40715. + return localHeaderOffset;
  40716. + }
  40717. - /** length of Java long in bytes. */
  40718. - static final int LONG_BYTE_SIZE = Long.SIZE / 8;
  40719. + public void setSize(long size) {
  40720. + this.size = size;
  40721. + }
  40722. - /*
  40723. - * Header signatures
  40724. - */
  40725. - static final long LOCSIG = 0x04034b50L; // "PK\003\004"
  40726. - static final long EXTSIG = 0x08074b50L; // "PK\007\008"
  40727. - static final long CENSIG = 0x02014b50L; // "PK\001\002"
  40728. - static final long ENDSIG = 0x06054b50L; // "PK\005\006"
  40729. + public void setDataOffset(long dataOffset) {
  40730. + this.dataOffset = dataOffset;
  40731. + }
  40732. - /*
  40733. - * Header sizes in bytes (including signatures)
  40734. - */
  40735. - static final int LOCHDR = 30; // LOC header size
  40736. - static final int EXTHDR = 16; // EXT header size
  40737. - static final int CENHDR = 46; // CEN header size
  40738. - static final int ENDHDR = 22; // END header size
  40739. + public void setName(String name) {
  40740. + this.name = name;
  40741. + }
  40742. - /*
  40743. - * Local file (LOC) header field offsets
  40744. - */
  40745. - static final int LOCVER = 4; // version needed to extract
  40746. - static final int LOCFLG = 6; // general purpose bit flag
  40747. - static final int LOCHOW = 8; // compression method
  40748. - static final int LOCTIM = 10; // modification time
  40749. - static final int LOCCRC = 14; // uncompressed file crc-32 value
  40750. - static final int LOCSIZ = 18; // compressed size
  40751. - static final int LOCLEN = 22; // uncompressed size
  40752. - static final int LOCNAM = 26; // filename length
  40753. - static final int LOCEXT = 28; // extra field length
  40754. -
  40755. - /*
  40756. - * Extra local (EXT) header field offsets
  40757. - */
  40758. - static final int EXTCRC = 4; // uncompressed file crc-32 value
  40759. - static final int EXTSIZ = 8; // compressed size
  40760. - static final int EXTLEN = 12; // uncompressed size
  40761. + public void setLocalHeaderOffset(long localHeaderOffset) {
  40762. + this.localHeaderOffset = localHeaderOffset;
  40763. + }
  40764. + }
  40765. - /*
  40766. - * Central directory (CEN) header field offsets
  40767. - */
  40768. - static final int CENVEM = 4; // version made by
  40769. - static final int CENVER = 6; // version needed to extract
  40770. - static final int CENFLG = 8; // encrypt, decrypt flags
  40771. - static final int CENHOW = 10; // compression method
  40772. - static final int CENTIM = 12; // modification time
  40773. - static final int CENCRC = 16; // uncompressed file crc-32 value
  40774. - static final int CENSIZ = 20; // compressed size
  40775. - static final int CENLEN = 24; // uncompressed size
  40776. - static final int CENNAM = 28; // filename length
  40777. - static final int CENEXT = 30; // extra field length
  40778. - static final int CENCOM = 32; // comment length
  40779. - static final int CENDSK = 34; // disk number start
  40780. - static final int CENATT = 36; // internal file attributes
  40781. - static final int CENATX = 38; // external file attributes
  40782. - static final int CENOFF = 42; // LOC header offset
  40783. -
  40784. - /*
  40785. - * End of central directory (END) header field offsets
  40786. + /**
  40787. + * Various constants for this {@link ZipFile}.
  40788. + *
  40789. + * <p>Referenced from {@link java.util.zip.ZipConstants}.
  40790. */
  40791. - static final int ENDSUB = 8; // number of entries on this disk
  40792. - static final int ENDTOT = 10; // total number of entries
  40793. - static final int ENDSIZ = 12; // central directory size in bytes
  40794. - static final int ENDOFF = 16; // offset of first CEN header
  40795. - static final int ENDCOM = 20; // zip file comment length
  40796. -
  40797. - private ZipConstants() {}
  40798. - }
  40799. + private static class ZipConstants {
  40800. + /** length of Java short in bytes. */
  40801. + static final int SHORT_BYTE_SIZE = Short.SIZE / 8;
  40802. +
  40803. + /** length of Java int in bytes. */
  40804. + static final int INT_BYTE_SIZE = Integer.SIZE / 8;
  40805. +
  40806. + /** length of Java long in bytes. */
  40807. + static final int LONG_BYTE_SIZE = Long.SIZE / 8;
  40808. +
  40809. + /*
  40810. + * Header signatures
  40811. + */
  40812. + static final long LOCSIG = 0x04034b50L; // "PK\003\004"
  40813. + static final long EXTSIG = 0x08074b50L; // "PK\007\008"
  40814. + static final long CENSIG = 0x02014b50L; // "PK\001\002"
  40815. + static final long ENDSIG = 0x06054b50L; // "PK\005\006"
  40816. +
  40817. + /*
  40818. + * Header sizes in bytes (including signatures)
  40819. + */
  40820. + static final int LOCHDR = 30; // LOC header size
  40821. + static final int EXTHDR = 16; // EXT header size
  40822. + static final int CENHDR = 46; // CEN header size
  40823. + static final int ENDHDR = 22; // END header size
  40824. +
  40825. + /*
  40826. + * Local file (LOC) header field offsets
  40827. + */
  40828. + static final int LOCVER = 4; // version needed to extract
  40829. + static final int LOCFLG = 6; // general purpose bit flag
  40830. + static final int LOCHOW = 8; // compression method
  40831. + static final int LOCTIM = 10; // modification time
  40832. + static final int LOCCRC = 14; // uncompressed file crc-32 value
  40833. + static final int LOCSIZ = 18; // compressed size
  40834. + static final int LOCLEN = 22; // uncompressed size
  40835. + static final int LOCNAM = 26; // filename length
  40836. + static final int LOCEXT = 28; // extra field length
  40837. +
  40838. + /*
  40839. + * Extra local (EXT) header field offsets
  40840. + */
  40841. + static final int EXTCRC = 4; // uncompressed file crc-32 value
  40842. + static final int EXTSIZ = 8; // compressed size
  40843. + static final int EXTLEN = 12; // uncompressed size
  40844. +
  40845. + /*
  40846. + * Central directory (CEN) header field offsets
  40847. + */
  40848. + static final int CENVEM = 4; // version made by
  40849. + static final int CENVER = 6; // version needed to extract
  40850. + static final int CENFLG = 8; // encrypt, decrypt flags
  40851. + static final int CENHOW = 10; // compression method
  40852. + static final int CENTIM = 12; // modification time
  40853. + static final int CENCRC = 16; // uncompressed file crc-32 value
  40854. + static final int CENSIZ = 20; // compressed size
  40855. + static final int CENLEN = 24; // uncompressed size
  40856. + static final int CENNAM = 28; // filename length
  40857. + static final int CENEXT = 30; // extra field length
  40858. + static final int CENCOM = 32; // comment length
  40859. + static final int CENDSK = 34; // disk number start
  40860. + static final int CENATT = 36; // internal file attributes
  40861. + static final int CENATX = 38; // external file attributes
  40862. + static final int CENOFF = 42; // LOC header offset
  40863. +
  40864. + /*
  40865. + * End of central directory (END) header field offsets
  40866. + */
  40867. + static final int ENDSUB = 8; // number of entries on this disk
  40868. + static final int ENDTOT = 10; // total number of entries
  40869. + static final int ENDSIZ = 12; // central directory size in bytes
  40870. + static final int ENDOFF = 16; // offset of first CEN header
  40871. + static final int ENDCOM = 20; // zip file comment length
  40872. +
  40873. + private ZipConstants() {}
  40874. + }
  40875. }
  40876. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/BoundedInputStreamTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/BoundedInputStreamTest.java
  40877. index 3847bc1d2ce01..e0825a1fe7862 100644
  40878. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/BoundedInputStreamTest.java
  40879. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/BoundedInputStreamTest.java
  40880. @@ -16,244 +16,223 @@ limitations under the License.
  40881. package org.tensorflow.lite.support.metadata;
  40882. import static com.google.common.truth.Truth.assertThat;
  40883. +
  40884. import static org.junit.Assert.assertArrayEquals;
  40885. import static org.junit.Assert.assertThrows;
  40886. -import java.nio.ByteBuffer;
  40887. import org.junit.Test;
  40888. import org.junit.runner.RunWith;
  40889. import org.robolectric.RobolectricTestRunner;
  40890. +import java.nio.ByteBuffer;
  40891. +
  40892. /** Tests of {@link BoundedInputStream}. */
  40893. @RunWith(RobolectricTestRunner.class)
  40894. public class BoundedInputStreamTest {
  40895. + private static final byte[] testBytes = new byte[] {10, 20, 30, 40, 50};
  40896. + private static final int[] testInts = new int[] {10, 20, 30, 40, 50};
  40897. + private static final int TEST_BYTES_LENGTH = testBytes.length;
  40898. +
  40899. + @Test
  40900. + public void boundedInputStream_negtiveStart_throwsException() throws Exception {
  40901. + long start = -1;
  40902. + long remaining = 2;
  40903. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  40904. + () -> createBoundedInputStream(testBytes, start, remaining));
  40905. + assertThat(exception).hasMessageThat().isEqualTo(String.format(
  40906. + "Invalid length of stream at offset=%d, length=%d", start, remaining));
  40907. + }
  40908. +
  40909. + @Test
  40910. + public void boundedInputStream_negtiveRemaining_throwsException() throws Exception {
  40911. + long start = 1;
  40912. + long remaining = -2;
  40913. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  40914. + () -> createBoundedInputStream(testBytes, start, remaining));
  40915. + assertThat(exception).hasMessageThat().isEqualTo(String.format(
  40916. + "Invalid length of stream at offset=%d, length=%d", start, remaining));
  40917. + }
  40918. +
  40919. + @Test
  40920. + public void available_atStart() throws Exception {
  40921. + int start = 3;
  40922. + BoundedInputStream boundedInputStream =
  40923. + createBoundedInputStream(testBytes, start, TEST_BYTES_LENGTH);
  40924. +
  40925. + int available = boundedInputStream.available();
  40926. + assertThat(available).isEqualTo(TEST_BYTES_LENGTH - start);
  40927. + }
  40928. +
  40929. + @Test
  40930. + public void available_afterRead() throws Exception {
  40931. + BoundedInputStream boundedInputStream =
  40932. + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  40933. + // Read a byte out of boundedInputStream. The number of remaining bytes is TEST_BYTES_LENGTH
  40934. + // -1.
  40935. + boundedInputStream.read();
  40936. +
  40937. + int available = boundedInputStream.available();
  40938. + assertThat(available).isEqualTo(TEST_BYTES_LENGTH - 1);
  40939. + }
  40940. +
  40941. + @Test
  40942. + public void read_repeatedRead() throws Exception {
  40943. + int[] values = new int[TEST_BYTES_LENGTH];
  40944. + BoundedInputStream boundedInputStream =
  40945. + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  40946. +
  40947. + for (int i = 0; i < TEST_BYTES_LENGTH; i++) {
  40948. + values[i] = boundedInputStream.read();
  40949. + }
  40950. +
  40951. + assertArrayEquals(testInts, values);
  40952. + }
  40953. +
  40954. + @Test
  40955. + public void read_reachTheEnd() throws Exception {
  40956. + BoundedInputStream boundedInputStream =
  40957. + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  40958. + boundedInputStream.skip(TEST_BYTES_LENGTH);
  40959. + int value = boundedInputStream.read();
  40960. +
  40961. + assertThat(value).isEqualTo(-1);
  40962. + }
  40963. +
  40964. + @Test
  40965. + public void read_channelSizeisSmallerThanTheStreamSpecified() throws Exception {
  40966. + BoundedInputStream boundedInputStream =
  40967. + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH + 1);
  40968. + boundedInputStream.skip(TEST_BYTES_LENGTH);
  40969. +
  40970. + int value = boundedInputStream.read();
  40971. +
  40972. + assertThat(value).isEqualTo(-1);
  40973. + }
  40974. - private static final byte[] testBytes = new byte[] {10, 20, 30, 40, 50};
  40975. - private static final int[] testInts = new int[] {10, 20, 30, 40, 50};
  40976. - private static final int TEST_BYTES_LENGTH = testBytes.length;
  40977. -
  40978. - @Test
  40979. - public void boundedInputStream_negtiveStart_throwsException() throws Exception {
  40980. - long start = -1;
  40981. - long remaining = 2;
  40982. - IllegalArgumentException exception =
  40983. - assertThrows(
  40984. - IllegalArgumentException.class,
  40985. - () -> createBoundedInputStream(testBytes, start, remaining));
  40986. - assertThat(exception)
  40987. - .hasMessageThat()
  40988. - .isEqualTo(
  40989. - String.format("Invalid length of stream at offset=%d, length=%d", start, remaining));
  40990. - }
  40991. -
  40992. - @Test
  40993. - public void boundedInputStream_negtiveRemaining_throwsException() throws Exception {
  40994. - long start = 1;
  40995. - long remaining = -2;
  40996. - IllegalArgumentException exception =
  40997. - assertThrows(
  40998. - IllegalArgumentException.class,
  40999. - () -> createBoundedInputStream(testBytes, start, remaining));
  41000. - assertThat(exception)
  41001. - .hasMessageThat()
  41002. - .isEqualTo(
  41003. - String.format("Invalid length of stream at offset=%d, length=%d", start, remaining));
  41004. - }
  41005. -
  41006. - @Test
  41007. - public void available_atStart() throws Exception {
  41008. - int start = 3;
  41009. - BoundedInputStream boundedInputStream =
  41010. - createBoundedInputStream(testBytes, start, TEST_BYTES_LENGTH);
  41011. -
  41012. - int available = boundedInputStream.available();
  41013. - assertThat(available).isEqualTo(TEST_BYTES_LENGTH - start);
  41014. - }
  41015. -
  41016. - @Test
  41017. - public void available_afterRead() throws Exception {
  41018. - BoundedInputStream boundedInputStream =
  41019. - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41020. - // Read a byte out of boundedInputStream. The number of remaining bytes is TEST_BYTES_LENGTH -1.
  41021. - boundedInputStream.read();
  41022. -
  41023. - int available = boundedInputStream.available();
  41024. - assertThat(available).isEqualTo(TEST_BYTES_LENGTH - 1);
  41025. - }
  41026. -
  41027. - @Test
  41028. - public void read_repeatedRead() throws Exception {
  41029. - int[] values = new int[TEST_BYTES_LENGTH];
  41030. - BoundedInputStream boundedInputStream =
  41031. - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41032. -
  41033. - for (int i = 0; i < TEST_BYTES_LENGTH; i++) {
  41034. - values[i] = boundedInputStream.read();
  41035. + @Test
  41036. + public void readArray_nullArray_throwsException() throws Exception {
  41037. + byte[] array = null;
  41038. + int offset = 0;
  41039. + int length = 1;
  41040. + BoundedInputStream boundedInputStream =
  41041. + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41042. +
  41043. + NullPointerException exception = assertThrows(
  41044. + NullPointerException.class, () -> boundedInputStream.read(array, offset, length));
  41045. + assertThat(exception).hasMessageThat().isEqualTo("The object reference is null.");
  41046. }
  41047. - assertArrayEquals(testInts, values);
  41048. - }
  41049. -
  41050. - @Test
  41051. - public void read_reachTheEnd() throws Exception {
  41052. - BoundedInputStream boundedInputStream =
  41053. - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41054. - boundedInputStream.skip(TEST_BYTES_LENGTH);
  41055. - int value = boundedInputStream.read();
  41056. -
  41057. - assertThat(value).isEqualTo(-1);
  41058. - }
  41059. -
  41060. - @Test
  41061. - public void read_channelSizeisSmallerThanTheStreamSpecified() throws Exception {
  41062. - BoundedInputStream boundedInputStream =
  41063. - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH + 1);
  41064. - boundedInputStream.skip(TEST_BYTES_LENGTH);
  41065. -
  41066. - int value = boundedInputStream.read();
  41067. -
  41068. - assertThat(value).isEqualTo(-1);
  41069. - }
  41070. -
  41071. - @Test
  41072. - public void readArray_nullArray_throwsException() throws Exception {
  41073. - byte[] array = null;
  41074. - int offset = 0;
  41075. - int length = 1;
  41076. - BoundedInputStream boundedInputStream =
  41077. - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41078. -
  41079. - NullPointerException exception =
  41080. - assertThrows(
  41081. - NullPointerException.class, () -> boundedInputStream.read(array, offset, length));
  41082. - assertThat(exception).hasMessageThat().isEqualTo("The object reference is null.");
  41083. - }
  41084. -
  41085. - @Test
  41086. - public void readArray_negativeOffset_throwsException() throws Exception {
  41087. - byte[] array = new byte[5];
  41088. - int offset = -1;
  41089. - int length = array.length;
  41090. - BoundedInputStream boundedInputStream =
  41091. - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41092. -
  41093. - IndexOutOfBoundsException exception =
  41094. - assertThrows(
  41095. - IndexOutOfBoundsException.class, () -> boundedInputStream.read(array, offset, length));
  41096. - assertThat(exception)
  41097. - .hasMessageThat()
  41098. - .isEqualTo(String.format("The start offset (%s) must not be negative", offset));
  41099. - }
  41100. -
  41101. - @Test
  41102. - public void readArray_OffsetEqualsArrayLength_throwsException() throws Exception {
  41103. - byte[] array = new byte[5];
  41104. - int offset = array.length;
  41105. - int length = 0;
  41106. - BoundedInputStream boundedInputStream =
  41107. - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41108. -
  41109. - IndexOutOfBoundsException exception =
  41110. - assertThrows(
  41111. - IndexOutOfBoundsException.class, () -> boundedInputStream.read(array, offset, length));
  41112. - assertThat(exception)
  41113. - .hasMessageThat()
  41114. - .isEqualTo(
  41115. - String.format(
  41116. + @Test
  41117. + public void readArray_negativeOffset_throwsException() throws Exception {
  41118. + byte[] array = new byte[5];
  41119. + int offset = -1;
  41120. + int length = array.length;
  41121. + BoundedInputStream boundedInputStream =
  41122. + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41123. +
  41124. + IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
  41125. + () -> boundedInputStream.read(array, offset, length));
  41126. + assertThat(exception).hasMessageThat().isEqualTo(
  41127. + String.format("The start offset (%s) must not be negative", offset));
  41128. + }
  41129. +
  41130. + @Test
  41131. + public void readArray_OffsetEqualsArrayLength_throwsException() throws Exception {
  41132. + byte[] array = new byte[5];
  41133. + int offset = array.length;
  41134. + int length = 0;
  41135. + BoundedInputStream boundedInputStream =
  41136. + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41137. +
  41138. + IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
  41139. + () -> boundedInputStream.read(array, offset, length));
  41140. + assertThat(exception).hasMessageThat().isEqualTo(String.format(
  41141. "The start offset (%s) must be less than size (%s)", offset, array.length));
  41142. - }
  41143. -
  41144. - @Test
  41145. - public void readArray_negativeLength_throwsException() throws Exception {
  41146. - byte[] array = new byte[5];
  41147. - int offset = 0;
  41148. - int length = -1;
  41149. - BoundedInputStream boundedInputStream =
  41150. - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41151. -
  41152. - IndexOutOfBoundsException exception =
  41153. - assertThrows(
  41154. - IndexOutOfBoundsException.class, () -> boundedInputStream.read(array, offset, length));
  41155. - assertThat(exception)
  41156. - .hasMessageThat()
  41157. - .isEqualTo(
  41158. - String.format(
  41159. + }
  41160. +
  41161. + @Test
  41162. + public void readArray_negativeLength_throwsException() throws Exception {
  41163. + byte[] array = new byte[5];
  41164. + int offset = 0;
  41165. + int length = -1;
  41166. + BoundedInputStream boundedInputStream =
  41167. + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41168. +
  41169. + IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
  41170. + () -> boundedInputStream.read(array, offset, length));
  41171. + assertThat(exception).hasMessageThat().isEqualTo(String.format(
  41172. "The maximumn number of bytes to read (%s) must not be negative", length));
  41173. - }
  41174. -
  41175. - @Test
  41176. - public void readArray_exceedEndOfArray_throwsException() throws Exception {
  41177. - byte[] array = new byte[5];
  41178. - int offset = 0;
  41179. - int length = array.length + 1;
  41180. - BoundedInputStream boundedInputStream =
  41181. - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41182. -
  41183. - IndexOutOfBoundsException exception =
  41184. - assertThrows(
  41185. - IndexOutOfBoundsException.class, () -> boundedInputStream.read(array, offset, length));
  41186. - assertThat(exception)
  41187. - .hasMessageThat()
  41188. - .isEqualTo(
  41189. - String.format(
  41190. - "The maximumn number of bytes to read (%s) must be less than size (%s)",
  41191. - length, array.length - offset + 1));
  41192. - }
  41193. -
  41194. - @Test
  41195. - public void readArray_zeroLength() throws Exception {
  41196. - byte[] array = new byte[5];
  41197. - int offset = 0;
  41198. - int length = 0;
  41199. - BoundedInputStream boundedInputStream =
  41200. - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41201. -
  41202. - int value = boundedInputStream.read(array, offset, length);
  41203. - assertThat(value).isEqualTo(0);
  41204. - }
  41205. -
  41206. - @Test
  41207. - public void readArray_exceedEndOfStream() throws Exception {
  41208. - byte[] array = new byte[5];
  41209. - int offset = 0;
  41210. - int length = 1;
  41211. - BoundedInputStream boundedInputStream =
  41212. - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41213. -
  41214. - // Move the position of the stream to the end.
  41215. - boundedInputStream.skip(TEST_BYTES_LENGTH);
  41216. -
  41217. - int value = boundedInputStream.read(array, offset, length);
  41218. -
  41219. - assertThat(value).isEqualTo(-1);
  41220. - }
  41221. -
  41222. - @Test
  41223. - public void readArray_lengthGreaterThanStreamRemaining() throws Exception {
  41224. - byte[] array = new byte[5];
  41225. - int offset = 1;
  41226. - int length = array.length - 1; // 4
  41227. - BoundedInputStream boundedInputStream =
  41228. - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41229. -
  41230. - // Moves the position of the stream to end-2.
  41231. - boundedInputStream.skip(TEST_BYTES_LENGTH - 2);
  41232. -
  41233. - // Reads the last two bytes of the stream to the array, and put the data at offset 1.
  41234. - int value = boundedInputStream.read(array, offset, length);
  41235. -
  41236. - byte[] expectedArray = new byte[] {0, 40, 50, 0, 0};
  41237. - assertArrayEquals(expectedArray, array);
  41238. - assertThat(value).isEqualTo(2);
  41239. -
  41240. - // Reachs the end of the stream, thus cannot read anymore.
  41241. - assertThat(boundedInputStream.read()).isEqualTo(-1);
  41242. - }
  41243. -
  41244. - private static BoundedInputStream createBoundedInputStream(
  41245. - final byte[] testBytes, long start, long remaining) {
  41246. - ByteBuffer buffer = ByteBuffer.wrap(testBytes);
  41247. - SeekableByteChannelCompat channel = new ByteBufferChannel(buffer);
  41248. - return new BoundedInputStream(channel, start, remaining);
  41249. - }
  41250. + }
  41251. +
  41252. + @Test
  41253. + public void readArray_exceedEndOfArray_throwsException() throws Exception {
  41254. + byte[] array = new byte[5];
  41255. + int offset = 0;
  41256. + int length = array.length + 1;
  41257. + BoundedInputStream boundedInputStream =
  41258. + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41259. +
  41260. + IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
  41261. + () -> boundedInputStream.read(array, offset, length));
  41262. + assertThat(exception).hasMessageThat().isEqualTo(String.format(
  41263. + "The maximumn number of bytes to read (%s) must be less than size (%s)", length,
  41264. + array.length - offset + 1));
  41265. + }
  41266. +
  41267. + @Test
  41268. + public void readArray_zeroLength() throws Exception {
  41269. + byte[] array = new byte[5];
  41270. + int offset = 0;
  41271. + int length = 0;
  41272. + BoundedInputStream boundedInputStream =
  41273. + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41274. +
  41275. + int value = boundedInputStream.read(array, offset, length);
  41276. + assertThat(value).isEqualTo(0);
  41277. + }
  41278. +
  41279. + @Test
  41280. + public void readArray_exceedEndOfStream() throws Exception {
  41281. + byte[] array = new byte[5];
  41282. + int offset = 0;
  41283. + int length = 1;
  41284. + BoundedInputStream boundedInputStream =
  41285. + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41286. +
  41287. + // Move the position of the stream to the end.
  41288. + boundedInputStream.skip(TEST_BYTES_LENGTH);
  41289. +
  41290. + int value = boundedInputStream.read(array, offset, length);
  41291. +
  41292. + assertThat(value).isEqualTo(-1);
  41293. + }
  41294. +
  41295. + @Test
  41296. + public void readArray_lengthGreaterThanStreamRemaining() throws Exception {
  41297. + byte[] array = new byte[5];
  41298. + int offset = 1;
  41299. + int length = array.length - 1; // 4
  41300. + BoundedInputStream boundedInputStream =
  41301. + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
  41302. +
  41303. + // Moves the position of the stream to end-2.
  41304. + boundedInputStream.skip(TEST_BYTES_LENGTH - 2);
  41305. +
  41306. + // Reads the last two bytes of the stream to the array, and put the data at offset 1.
  41307. + int value = boundedInputStream.read(array, offset, length);
  41308. +
  41309. + byte[] expectedArray = new byte[] {0, 40, 50, 0, 0};
  41310. + assertArrayEquals(expectedArray, array);
  41311. + assertThat(value).isEqualTo(2);
  41312. +
  41313. + // Reachs the end of the stream, thus cannot read anymore.
  41314. + assertThat(boundedInputStream.read()).isEqualTo(-1);
  41315. + }
  41316. +
  41317. + private static BoundedInputStream createBoundedInputStream(
  41318. + final byte[] testBytes, long start, long remaining) {
  41319. + ByteBuffer buffer = ByteBuffer.wrap(testBytes);
  41320. + SeekableByteChannelCompat channel = new ByteBufferChannel(buffer);
  41321. + return new BoundedInputStream(channel, start, remaining);
  41322. + }
  41323. }
  41324. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ByteBufferChannelTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ByteBufferChannelTest.java
  41325. index abda43058aa90..ce625de8034b7 100644
  41326. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ByteBufferChannelTest.java
  41327. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ByteBufferChannelTest.java
  41328. @@ -16,254 +16,252 @@ limitations under the License.
  41329. package org.tensorflow.lite.support.metadata;
  41330. import static com.google.common.truth.Truth.assertThat;
  41331. -import static java.nio.charset.StandardCharsets.UTF_8;
  41332. +
  41333. import static org.junit.Assert.assertThrows;
  41334. -import java.nio.ByteBuffer;
  41335. +import static java.nio.charset.StandardCharsets.UTF_8;
  41336. +
  41337. import org.junit.Test;
  41338. import org.junit.runner.RunWith;
  41339. import org.robolectric.RobolectricTestRunner;
  41340. +import java.nio.ByteBuffer;
  41341. +
  41342. /** Tests of {@link ByteBufferChannel}. */
  41343. @RunWith(RobolectricTestRunner.class)
  41344. public final class ByteBufferChannelTest {
  41345. - private static final String VALID_STRING = "1234567890";
  41346. - private final ByteBuffer validByteBuffer = ByteBuffer.wrap(VALID_STRING.getBytes(UTF_8));
  41347. - private final int validByteBufferLength = validByteBuffer.limit();
  41348. -
  41349. - @Test
  41350. - public void byteBufferChannel_validByteBuffer() {
  41351. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41352. - assertThat(byteBufferChannel).isNotNull();
  41353. - }
  41354. -
  41355. - @Test
  41356. - public void byteBufferChannel_nullByteBuffer_throwsException() {
  41357. - NullPointerException exception =
  41358. - assertThrows(NullPointerException.class, () -> new ByteBufferChannel(/*buffer=*/ null));
  41359. - assertThat(exception).hasMessageThat().isEqualTo("The ByteBuffer cannot be null.");
  41360. - }
  41361. -
  41362. - @Test
  41363. - public void isOpen_openedByteBufferChannel() {
  41364. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41365. - assertThat(byteBufferChannel.isOpen()).isTrue();
  41366. - }
  41367. -
  41368. - @Test
  41369. - public void position_newByteBufferChannelWithInitialPosition0() {
  41370. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41371. - long position = byteBufferChannel.position();
  41372. -
  41373. - long expectedPosition = 0;
  41374. - assertThat(position).isEqualTo(expectedPosition);
  41375. - }
  41376. -
  41377. - @Test
  41378. - public void position_validNewPosition() {
  41379. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41380. - long validNewPosition = 6;
  41381. -
  41382. - byteBufferChannel.position(validNewPosition);
  41383. - long position = byteBufferChannel.position();
  41384. -
  41385. - assertThat(position).isEqualTo(validNewPosition);
  41386. - }
  41387. -
  41388. - @Test
  41389. - public void position_negtiveNewPosition_throwsException() {
  41390. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41391. - long invalidNewPosition = -1;
  41392. -
  41393. - IllegalArgumentException exception =
  41394. - assertThrows(
  41395. - IllegalArgumentException.class, () -> byteBufferChannel.position(invalidNewPosition));
  41396. - assertThat(exception)
  41397. - .hasMessageThat()
  41398. - .isEqualTo("The new position should be non-negative and be less than Integer.MAX_VALUE.");
  41399. - }
  41400. -
  41401. - @Test
  41402. - public void position_newPositionGreaterThanMaxIntegerValue_throwsException() {
  41403. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41404. - long invalidNewPosition = Integer.MAX_VALUE + 1;
  41405. -
  41406. - IllegalArgumentException exception =
  41407. - assertThrows(
  41408. - IllegalArgumentException.class, () -> byteBufferChannel.position(invalidNewPosition));
  41409. - assertThat(exception)
  41410. - .hasMessageThat()
  41411. - .isEqualTo("The new position should be non-negative and be less than Integer.MAX_VALUE.");
  41412. - }
  41413. -
  41414. - @Test
  41415. - public void position_newPositionGreaterThanByfferLength_throwsException() {
  41416. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41417. - long invalidNewPosition = (long) validByteBufferLength + 1;
  41418. -
  41419. - IllegalArgumentException exception =
  41420. - assertThrows(
  41421. - IllegalArgumentException.class, () -> byteBufferChannel.position(invalidNewPosition));
  41422. - assertThat(exception).hasMessageThat().isEqualTo("newPosition > limit: (11 > 10)");
  41423. - }
  41424. -
  41425. - @Test
  41426. - public void read_fromPosition0() {
  41427. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41428. - long validNewPosition = 0;
  41429. -
  41430. - byteBufferChannel.position(validNewPosition);
  41431. - ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
  41432. - int numBytes = byteBufferChannel.read(dstBuffer);
  41433. -
  41434. - assertThat(numBytes).isEqualTo(validByteBufferLength);
  41435. - assertThat(dstBuffer).isEqualTo(validByteBuffer);
  41436. - }
  41437. -
  41438. - @Test
  41439. - public void read_fromPosition5() {
  41440. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41441. - long validNewPosition = 5;
  41442. -
  41443. - byteBufferChannel.position(validNewPosition);
  41444. - ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
  41445. - int numBytes = byteBufferChannel.read(dstBuffer);
  41446. -
  41447. - assertThat(numBytes).isEqualTo(validByteBufferLength - (int) validNewPosition);
  41448. - String dstString = convertByteByfferToString(dstBuffer, numBytes);
  41449. - String expectedString = "67890";
  41450. - assertThat(dstString).isEqualTo(expectedString);
  41451. - }
  41452. -
  41453. - @Test
  41454. - public void read_fromPositionValidByteBufferLength() {
  41455. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41456. - long validNewPosition = validByteBufferLength;
  41457. -
  41458. - byteBufferChannel.position(validNewPosition);
  41459. - ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
  41460. - int numBytes = byteBufferChannel.read(dstBuffer);
  41461. -
  41462. - assertThat(numBytes).isEqualTo(-1);
  41463. - }
  41464. -
  41465. - @Test
  41466. - public void read_dstBufferRemaining0() {
  41467. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41468. - long validNewPosition = 0;
  41469. -
  41470. - byteBufferChannel.position(validNewPosition);
  41471. - ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
  41472. - dstBuffer.position(validByteBufferLength);
  41473. - int numBytes = byteBufferChannel.read(dstBuffer);
  41474. -
  41475. - assertThat(numBytes).isEqualTo(0);
  41476. - String dstString = convertByteByfferToString(dstBuffer, numBytes);
  41477. - String expectedString = "";
  41478. - assertThat(dstString).isEqualTo(expectedString);
  41479. - }
  41480. -
  41481. - @Test
  41482. - public void read_dstBufferIsSmallerThanTheBufferChannel() {
  41483. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41484. - int dstBufferLength = 3;
  41485. -
  41486. - ByteBuffer dstBuffer = ByteBuffer.allocate(dstBufferLength);
  41487. - int numBytes = byteBufferChannel.read(dstBuffer);
  41488. -
  41489. - assertThat(numBytes).isEqualTo(dstBufferLength);
  41490. - assertThat(validByteBuffer.position()).isEqualTo(dstBufferLength);
  41491. -
  41492. - String dstString = convertByteByfferToString(dstBuffer, dstBufferLength);
  41493. - String expectedString = "123";
  41494. - assertThat(dstString).isEqualTo(expectedString);
  41495. - }
  41496. -
  41497. - @Test
  41498. - public void size_validBuffer() {
  41499. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41500. - assertThat(byteBufferChannel.size()).isEqualTo(validByteBufferLength);
  41501. - }
  41502. -
  41503. - @Test
  41504. - public void truncate_validSizeAndPosition0() {
  41505. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41506. - long truncateSize = 3;
  41507. -
  41508. - byteBufferChannel.truncate(truncateSize);
  41509. -
  41510. - assertThat(byteBufferChannel.size()).isEqualTo(truncateSize);
  41511. - assertThat(byteBufferChannel.position()).isEqualTo(0);
  41512. - }
  41513. -
  41514. - @Test
  41515. - public void truncate_validSizeAndPosition5() {
  41516. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41517. - long validNewPosition = 5;
  41518. -
  41519. - byteBufferChannel.position(validNewPosition);
  41520. - long truncateSize = 3;
  41521. - byteBufferChannel.truncate(truncateSize);
  41522. -
  41523. - assertThat(byteBufferChannel.position()).isEqualTo(truncateSize);
  41524. - }
  41525. -
  41526. - @Test
  41527. - public void truncate_sizeNotSmallerThanBufferSize() {
  41528. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41529. - long truncateSize = (long) validByteBufferLength;
  41530. -
  41531. - byteBufferChannel.truncate(truncateSize);
  41532. -
  41533. - assertThat(byteBufferChannel.position()).isEqualTo(0);
  41534. - }
  41535. -
  41536. - @Test
  41537. - public void write_srcBufferSmallerThanBufferChannel() {
  41538. - String srcString = "5555";
  41539. - long newPosition = 3;
  41540. - String expectedString = "1235555890";
  41541. - ByteBuffer srcBuffer = ByteBuffer.wrap(srcString.getBytes(UTF_8));
  41542. -
  41543. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41544. - byteBufferChannel.position(newPosition);
  41545. - byteBufferChannel.write(srcBuffer);
  41546. -
  41547. - assertThat(byteBufferChannel.position()).isEqualTo(newPosition + srcString.length());
  41548. - ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
  41549. - byteBufferChannel.position(0);
  41550. - byteBufferChannel.read(dstBuffer);
  41551. - ByteBuffer expectedBuffer = ByteBuffer.wrap(expectedString.getBytes(UTF_8));
  41552. - dstBuffer.rewind();
  41553. - expectedBuffer.rewind();
  41554. - assertThat(dstBuffer).isEqualTo(expectedBuffer);
  41555. - }
  41556. -
  41557. - @Test
  41558. - public void write_srcBufferGreaterThanBufferChannel() {
  41559. - String srcString = "5555";
  41560. - long newPosition = 8;
  41561. - String expectedString = "1234567855";
  41562. - ByteBuffer srcBuffer = ByteBuffer.wrap(srcString.getBytes(UTF_8));
  41563. -
  41564. - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41565. - byteBufferChannel.position(newPosition);
  41566. - byteBufferChannel.write(srcBuffer);
  41567. - assertThat(byteBufferChannel.position()).isEqualTo(validByteBufferLength);
  41568. -
  41569. - ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
  41570. - byteBufferChannel.position(0);
  41571. - byteBufferChannel.read(dstBuffer);
  41572. - ByteBuffer expectedBuffer = ByteBuffer.wrap(expectedString.getBytes(UTF_8));
  41573. - dstBuffer.rewind();
  41574. - expectedBuffer.rewind();
  41575. - assertThat(dstBuffer).isEqualTo(expectedBuffer);
  41576. - }
  41577. -
  41578. - private static String convertByteByfferToString(ByteBuffer buffer, int arrLength) {
  41579. - byte[] bytes = new byte[arrLength];
  41580. - buffer.rewind();
  41581. - buffer.get(bytes);
  41582. - return new String(bytes, UTF_8);
  41583. - }
  41584. + private static final String VALID_STRING = "1234567890";
  41585. + private final ByteBuffer validByteBuffer = ByteBuffer.wrap(VALID_STRING.getBytes(UTF_8));
  41586. + private final int validByteBufferLength = validByteBuffer.limit();
  41587. +
  41588. + @Test
  41589. + public void byteBufferChannel_validByteBuffer() {
  41590. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41591. + assertThat(byteBufferChannel).isNotNull();
  41592. + }
  41593. +
  41594. + @Test
  41595. + public void byteBufferChannel_nullByteBuffer_throwsException() {
  41596. + NullPointerException exception = assertThrows(
  41597. + NullPointerException.class, () -> new ByteBufferChannel(/*buffer=*/null));
  41598. + assertThat(exception).hasMessageThat().isEqualTo("The ByteBuffer cannot be null.");
  41599. + }
  41600. +
  41601. + @Test
  41602. + public void isOpen_openedByteBufferChannel() {
  41603. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41604. + assertThat(byteBufferChannel.isOpen()).isTrue();
  41605. + }
  41606. +
  41607. + @Test
  41608. + public void position_newByteBufferChannelWithInitialPosition0() {
  41609. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41610. + long position = byteBufferChannel.position();
  41611. +
  41612. + long expectedPosition = 0;
  41613. + assertThat(position).isEqualTo(expectedPosition);
  41614. + }
  41615. +
  41616. + @Test
  41617. + public void position_validNewPosition() {
  41618. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41619. + long validNewPosition = 6;
  41620. +
  41621. + byteBufferChannel.position(validNewPosition);
  41622. + long position = byteBufferChannel.position();
  41623. +
  41624. + assertThat(position).isEqualTo(validNewPosition);
  41625. + }
  41626. +
  41627. + @Test
  41628. + public void position_negtiveNewPosition_throwsException() {
  41629. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41630. + long invalidNewPosition = -1;
  41631. +
  41632. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  41633. + () -> byteBufferChannel.position(invalidNewPosition));
  41634. + assertThat(exception).hasMessageThat().isEqualTo(
  41635. + "The new position should be non-negative and be less than Integer.MAX_VALUE.");
  41636. + }
  41637. +
  41638. + @Test
  41639. + public void position_newPositionGreaterThanMaxIntegerValue_throwsException() {
  41640. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41641. + long invalidNewPosition = Integer.MAX_VALUE + 1;
  41642. +
  41643. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  41644. + () -> byteBufferChannel.position(invalidNewPosition));
  41645. + assertThat(exception).hasMessageThat().isEqualTo(
  41646. + "The new position should be non-negative and be less than Integer.MAX_VALUE.");
  41647. + }
  41648. +
  41649. + @Test
  41650. + public void position_newPositionGreaterThanByfferLength_throwsException() {
  41651. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41652. + long invalidNewPosition = (long) validByteBufferLength + 1;
  41653. +
  41654. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  41655. + () -> byteBufferChannel.position(invalidNewPosition));
  41656. + assertThat(exception).hasMessageThat().isEqualTo("newPosition > limit: (11 > 10)");
  41657. + }
  41658. +
  41659. + @Test
  41660. + public void read_fromPosition0() {
  41661. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41662. + long validNewPosition = 0;
  41663. +
  41664. + byteBufferChannel.position(validNewPosition);
  41665. + ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
  41666. + int numBytes = byteBufferChannel.read(dstBuffer);
  41667. +
  41668. + assertThat(numBytes).isEqualTo(validByteBufferLength);
  41669. + assertThat(dstBuffer).isEqualTo(validByteBuffer);
  41670. + }
  41671. +
  41672. + @Test
  41673. + public void read_fromPosition5() {
  41674. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41675. + long validNewPosition = 5;
  41676. +
  41677. + byteBufferChannel.position(validNewPosition);
  41678. + ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
  41679. + int numBytes = byteBufferChannel.read(dstBuffer);
  41680. +
  41681. + assertThat(numBytes).isEqualTo(validByteBufferLength - (int) validNewPosition);
  41682. + String dstString = convertByteByfferToString(dstBuffer, numBytes);
  41683. + String expectedString = "67890";
  41684. + assertThat(dstString).isEqualTo(expectedString);
  41685. + }
  41686. +
  41687. + @Test
  41688. + public void read_fromPositionValidByteBufferLength() {
  41689. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41690. + long validNewPosition = validByteBufferLength;
  41691. +
  41692. + byteBufferChannel.position(validNewPosition);
  41693. + ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
  41694. + int numBytes = byteBufferChannel.read(dstBuffer);
  41695. +
  41696. + assertThat(numBytes).isEqualTo(-1);
  41697. + }
  41698. +
  41699. + @Test
  41700. + public void read_dstBufferRemaining0() {
  41701. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41702. + long validNewPosition = 0;
  41703. +
  41704. + byteBufferChannel.position(validNewPosition);
  41705. + ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
  41706. + dstBuffer.position(validByteBufferLength);
  41707. + int numBytes = byteBufferChannel.read(dstBuffer);
  41708. +
  41709. + assertThat(numBytes).isEqualTo(0);
  41710. + String dstString = convertByteByfferToString(dstBuffer, numBytes);
  41711. + String expectedString = "";
  41712. + assertThat(dstString).isEqualTo(expectedString);
  41713. + }
  41714. +
  41715. + @Test
  41716. + public void read_dstBufferIsSmallerThanTheBufferChannel() {
  41717. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41718. + int dstBufferLength = 3;
  41719. +
  41720. + ByteBuffer dstBuffer = ByteBuffer.allocate(dstBufferLength);
  41721. + int numBytes = byteBufferChannel.read(dstBuffer);
  41722. +
  41723. + assertThat(numBytes).isEqualTo(dstBufferLength);
  41724. + assertThat(validByteBuffer.position()).isEqualTo(dstBufferLength);
  41725. +
  41726. + String dstString = convertByteByfferToString(dstBuffer, dstBufferLength);
  41727. + String expectedString = "123";
  41728. + assertThat(dstString).isEqualTo(expectedString);
  41729. + }
  41730. +
  41731. + @Test
  41732. + public void size_validBuffer() {
  41733. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41734. + assertThat(byteBufferChannel.size()).isEqualTo(validByteBufferLength);
  41735. + }
  41736. +
  41737. + @Test
  41738. + public void truncate_validSizeAndPosition0() {
  41739. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41740. + long truncateSize = 3;
  41741. +
  41742. + byteBufferChannel.truncate(truncateSize);
  41743. +
  41744. + assertThat(byteBufferChannel.size()).isEqualTo(truncateSize);
  41745. + assertThat(byteBufferChannel.position()).isEqualTo(0);
  41746. + }
  41747. +
  41748. + @Test
  41749. + public void truncate_validSizeAndPosition5() {
  41750. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41751. + long validNewPosition = 5;
  41752. +
  41753. + byteBufferChannel.position(validNewPosition);
  41754. + long truncateSize = 3;
  41755. + byteBufferChannel.truncate(truncateSize);
  41756. +
  41757. + assertThat(byteBufferChannel.position()).isEqualTo(truncateSize);
  41758. + }
  41759. +
  41760. + @Test
  41761. + public void truncate_sizeNotSmallerThanBufferSize() {
  41762. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41763. + long truncateSize = (long) validByteBufferLength;
  41764. +
  41765. + byteBufferChannel.truncate(truncateSize);
  41766. +
  41767. + assertThat(byteBufferChannel.position()).isEqualTo(0);
  41768. + }
  41769. +
  41770. + @Test
  41771. + public void write_srcBufferSmallerThanBufferChannel() {
  41772. + String srcString = "5555";
  41773. + long newPosition = 3;
  41774. + String expectedString = "1235555890";
  41775. + ByteBuffer srcBuffer = ByteBuffer.wrap(srcString.getBytes(UTF_8));
  41776. +
  41777. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41778. + byteBufferChannel.position(newPosition);
  41779. + byteBufferChannel.write(srcBuffer);
  41780. +
  41781. + assertThat(byteBufferChannel.position()).isEqualTo(newPosition + srcString.length());
  41782. + ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
  41783. + byteBufferChannel.position(0);
  41784. + byteBufferChannel.read(dstBuffer);
  41785. + ByteBuffer expectedBuffer = ByteBuffer.wrap(expectedString.getBytes(UTF_8));
  41786. + dstBuffer.rewind();
  41787. + expectedBuffer.rewind();
  41788. + assertThat(dstBuffer).isEqualTo(expectedBuffer);
  41789. + }
  41790. +
  41791. + @Test
  41792. + public void write_srcBufferGreaterThanBufferChannel() {
  41793. + String srcString = "5555";
  41794. + long newPosition = 8;
  41795. + String expectedString = "1234567855";
  41796. + ByteBuffer srcBuffer = ByteBuffer.wrap(srcString.getBytes(UTF_8));
  41797. +
  41798. + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
  41799. + byteBufferChannel.position(newPosition);
  41800. + byteBufferChannel.write(srcBuffer);
  41801. + assertThat(byteBufferChannel.position()).isEqualTo(validByteBufferLength);
  41802. +
  41803. + ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
  41804. + byteBufferChannel.position(0);
  41805. + byteBufferChannel.read(dstBuffer);
  41806. + ByteBuffer expectedBuffer = ByteBuffer.wrap(expectedString.getBytes(UTF_8));
  41807. + dstBuffer.rewind();
  41808. + expectedBuffer.rewind();
  41809. + assertThat(dstBuffer).isEqualTo(expectedBuffer);
  41810. + }
  41811. +
  41812. + private static String convertByteByfferToString(ByteBuffer buffer, int arrLength) {
  41813. + byte[] bytes = new byte[arrLength];
  41814. + buffer.rewind();
  41815. + buffer.get(bytes);
  41816. + return new String(bytes, UTF_8);
  41817. + }
  41818. }
  41819. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java
  41820. index 67fc50d9f57b1..9f1173a1ea19b 100644
  41821. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java
  41822. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java
  41823. @@ -16,24 +16,20 @@ limitations under the License.
  41824. package org.tensorflow.lite.support.metadata;
  41825. import static com.google.common.truth.Truth.assertThat;
  41826. +
  41827. import static org.junit.Assert.assertArrayEquals;
  41828. import static org.junit.Assert.assertThrows;
  41829. import android.content.Context;
  41830. import android.content.res.AssetFileDescriptor;
  41831. +
  41832. import androidx.test.core.app.ApplicationProvider;
  41833. +
  41834. import com.google.flatbuffers.FlatBufferBuilder;
  41835. -import java.io.FileInputStream;
  41836. -import java.io.InputStream;
  41837. -import java.nio.ByteBuffer;
  41838. -import java.nio.channels.FileChannel;
  41839. -import java.util.Arrays;
  41840. -import java.util.Collection;
  41841. -import java.util.HashSet;
  41842. -import java.util.Random;
  41843. -import java.util.Set;
  41844. +
  41845. import org.apache.commons.io.IOUtils;
  41846. import org.checkerframework.checker.nullness.qual.Nullable;
  41847. +import org.junit.Ignore;
  41848. import org.junit.Test;
  41849. import org.junit.runner.RunWith;
  41850. import org.junit.runners.Suite;
  41851. @@ -56,931 +52,903 @@ import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
  41852. import org.tensorflow.lite.support.metadata.schema.SubGraphMetadata;
  41853. import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
  41854. -import org.junit.Ignore;
  41855. +import java.io.FileInputStream;
  41856. +import java.io.InputStream;
  41857. +import java.nio.ByteBuffer;
  41858. +import java.nio.channels.FileChannel;
  41859. +import java.util.Arrays;
  41860. +import java.util.Collection;
  41861. +import java.util.HashSet;
  41862. +import java.util.Random;
  41863. +import java.util.Set;
  41864. /** Tests of {@link MetadataExtractor}. */
  41865. @RunWith(Suite.class)
  41866. @SuiteClasses({MetadataExtractorTest.General.class, MetadataExtractorTest.InputTensorType.class})
  41867. public class MetadataExtractorTest {
  41868. - private static final int[] validShape = new int[] {4, 10, 10, 3};
  41869. - private static final byte DATA_TYPE = TensorType.UINT8;
  41870. - private static final byte CONTENT_PROPERTIES_TYPE = ContentProperties.ImageProperties;
  41871. - private static final float VALID_SCALE = 3.3f;
  41872. - private static final long VALID_ZERO_POINT = 2;
  41873. - private static final float DEFAULT_SCALE = 0.0f;
  41874. - private static final long DEFAULT_ZERO_POINT = 0;
  41875. - private static final String MODEL_NAME = "model.tflite";
  41876. - // Scale and zero point should both be a single value, not an array.
  41877. - private static final float[] invalidScale = new float[] {0.0f, 1.2f};
  41878. - private static final long[] invalidZeroPoint = new long[] {1, 2};
  41879. - private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite";
  41880. - // labels.txt is packed in mobilenet_v1_1.0_224_quant.tflite as an associated file.
  41881. - private static final String VALID_LABEL_FILE_NAME = "labels.txt";
  41882. - // invalid.txt is not packed in mobilenet_v1_1.0_224_quant.tflite.
  41883. - private static final String INVALID_LABEL_FILE_NAME = "invalid.txt";
  41884. - private static final int EMPTY_FLATBUFFER_VECTOR = -1;
  41885. - private static final String TFLITE_MODEL_IDENTIFIER = "TFL3";
  41886. - private static final String TFLITE_METADATA_IDENTIFIER = "M001";
  41887. -
  41888. - /** General tests of MetadataExtractor. */
  41889. - @RunWith(RobolectricTestRunner.class)
  41890. - public static final class General extends MetadataExtractorTest {
  41891. -
  41892. - @Test
  41893. - public void hasMetadata_modelWithMetadata() throws Exception {
  41894. - // Creates a model flatbuffer with metadata.
  41895. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  41896. -
  41897. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  41898. - assertThat(metadataExtractor.hasMetadata()).isTrue();
  41899. - }
  41900. -
  41901. - @Test
  41902. - public void hasMetadata_modelWithoutMetadata() throws Exception {
  41903. - // Creates a model flatbuffer without metadata.
  41904. - ByteBuffer modelWithoutMetadata = createModelByteBuffer(/*metadataBuffer=*/ null, DATA_TYPE);
  41905. -
  41906. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
  41907. - assertThat(metadataExtractor.hasMetadata()).isFalse();
  41908. - }
  41909. -
  41910. - @Ignore
  41911. - @Test
  41912. - public void getAssociatedFile_validAssociateFile() throws Exception {
  41913. - ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
  41914. - MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
  41915. - InputStream associateFileStream =
  41916. - mobileNetMetadataExtractor.getAssociatedFile(VALID_LABEL_FILE_NAME);
  41917. -
  41918. - // Reads the golden file from context.
  41919. - Context context = ApplicationProvider.getApplicationContext();
  41920. - InputStream goldenAssociateFileStream = context.getAssets().open(VALID_LABEL_FILE_NAME);
  41921. - assertThat(IOUtils.contentEquals(goldenAssociateFileStream, associateFileStream)).isTrue();
  41922. - }
  41923. -
  41924. - @Ignore
  41925. - @Test
  41926. - public void getAssociatedFile_invalidAssociateFile() throws Exception {
  41927. - ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
  41928. - MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
  41929. - IllegalArgumentException exception =
  41930. - assertThrows(
  41931. - IllegalArgumentException.class,
  41932. - () -> mobileNetMetadataExtractor.getAssociatedFile(INVALID_LABEL_FILE_NAME));
  41933. - assertThat(exception)
  41934. - .hasMessageThat()
  41935. - .isEqualTo(
  41936. - String.format(
  41937. - "The file, %s, does not exist in the zip file.", INVALID_LABEL_FILE_NAME));
  41938. - }
  41939. -
  41940. - @Ignore
  41941. - @Test
  41942. - public void getAssociatedFile_nullFileName() throws Exception {
  41943. - ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
  41944. - MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
  41945. - IllegalArgumentException exception =
  41946. - assertThrows(
  41947. - IllegalArgumentException.class,
  41948. - () -> mobileNetMetadataExtractor.getAssociatedFile(/*fileName=*/ null));
  41949. - assertThat(exception)
  41950. - .hasMessageThat()
  41951. - .contains("The file, null, does not exist in the zip file.");
  41952. - }
  41953. -
  41954. - @Test
  41955. - public void getAssociatedFile_nonZipModel_throwsException() throws Exception {
  41956. - // Creates a model flatbuffer with metadata.
  41957. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  41958. -
  41959. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  41960. - IllegalStateException exception =
  41961. - assertThrows(
  41962. - IllegalStateException.class,
  41963. - () -> metadataExtractor.getAssociatedFile(VALID_LABEL_FILE_NAME));
  41964. - assertThat(exception)
  41965. - .hasMessageThat()
  41966. - .contains("This model does not contain associated files, and is not a Zip file.");
  41967. - }
  41968. -
  41969. - @Test
  41970. - public void getAssociatedFileNames_nonZipModel_throwsException() throws Exception {
  41971. - // Creates a model flatbuffer with metadata.
  41972. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  41973. -
  41974. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  41975. - IllegalStateException exception =
  41976. - assertThrows(IllegalStateException.class, metadataExtractor::getAssociatedFileNames);
  41977. - assertThat(exception)
  41978. - .hasMessageThat()
  41979. - .contains("This model does not contain associated files, and is not a Zip file.");
  41980. - }
  41981. -
  41982. - @Ignore
  41983. - @Test
  41984. - public void getAssociatedFileNames_validFileNames() throws Exception {
  41985. - ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
  41986. - MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
  41987. - Set<String> expectedSet = new HashSet<>();
  41988. - expectedSet.add(VALID_LABEL_FILE_NAME);
  41989. - assertThat(mobileNetMetadataExtractor.getAssociatedFileNames()).isEqualTo(expectedSet);
  41990. - }
  41991. -
  41992. - @Test
  41993. - public void metadataExtractor_loadNullBuffer_throwsException() {
  41994. - ByteBuffer nullBuffer = null;
  41995. - NullPointerException exception =
  41996. - assertThrows(NullPointerException.class, () -> new MetadataExtractor(nullBuffer));
  41997. - assertThat(exception).hasMessageThat().contains("Model flatbuffer cannot be null.");
  41998. - }
  41999. -
  42000. - @Test
  42001. - public void metadataExtractor_loadRandomBuffer_throwsException() {
  42002. - ByteBuffer randomBuffer = createRandomByteBuffer();
  42003. - IllegalArgumentException exception =
  42004. - assertThrows(IllegalArgumentException.class, () -> new MetadataExtractor(randomBuffer));
  42005. - assertThat(exception)
  42006. - .hasMessageThat()
  42007. - .contains(
  42008. - "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
  42009. - + " flatbuffer.");
  42010. - }
  42011. -
  42012. - @Test
  42013. - public void metadataExtractor_loadModelWithInvalidIdentifier_throwsException() {
  42014. - // Creates a model with an invalid identifier.
  42015. - String invalidIdentifier = "INVI";
  42016. - FlatBufferBuilder builder = new FlatBufferBuilder();
  42017. - Model.startModel(builder);
  42018. - int model = Model.endModel(builder);
  42019. - builder.finish(model, invalidIdentifier);
  42020. - ByteBuffer modelBuffer = builder.dataBuffer();
  42021. -
  42022. - IllegalArgumentException exception =
  42023. - assertThrows(IllegalArgumentException.class, () -> new MetadataExtractor(modelBuffer));
  42024. - assertThat(exception)
  42025. - .hasMessageThat()
  42026. - .contains(
  42027. - "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
  42028. - + " flatbuffer.");
  42029. - }
  42030. -
  42031. - @Test
  42032. - public void metadataExtractor_loadMetadataWithInvalidIdentifier_throwsException() {
  42033. - // Creates a model with metadata which contains an invalid identifier.
  42034. - String invalidIdentifier = "INVI";
  42035. - ByteBuffer metadata = createMetadataByteBuffer(invalidIdentifier, null);
  42036. - ByteBuffer modelBuffer = createModelByteBuffer(metadata, DATA_TYPE);
  42037. -
  42038. - IllegalArgumentException exception =
  42039. - assertThrows(IllegalArgumentException.class, () -> new MetadataExtractor(modelBuffer));
  42040. - assertThat(exception)
  42041. - .hasMessageThat()
  42042. - .contains(
  42043. - "The identifier of the metadata is invalid. The buffer may not be a valid TFLite"
  42044. - + " metadata flatbuffer.");
  42045. - }
  42046. -
  42047. - @Test
  42048. - public void getInputTensorCount_validModelFile() throws Exception {
  42049. - // Creates a model flatbuffer with metadata.
  42050. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42051. -
  42052. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42053. - int count = metadataExtractor.getInputTensorCount();
  42054. - assertThat(count).isEqualTo(3);
  42055. - }
  42056. -
  42057. - @Test
  42058. - public void getOutputTensorCount_validModelFile() throws Exception {
  42059. - // Creates a model flatbuffer with metadata.
  42060. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42061. -
  42062. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42063. - int count = metadataExtractor.getOutputTensorCount();
  42064. - assertThat(count).isEqualTo(3);
  42065. - }
  42066. -
  42067. - @Test
  42068. - public void getInputTensorShape_validTensorShape() throws Exception {
  42069. - // Creates a model flatbuffer with metadata.
  42070. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42071. -
  42072. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42073. - int[] shape = metadataExtractor.getInputTensorShape(0);
  42074. - assertArrayEquals(validShape, shape);
  42075. - }
  42076. -
  42077. - @Test
  42078. - public void getInputTensorShape_emptyTensor() throws Exception {
  42079. - // Creates a model flatbuffer with metadata.
  42080. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42081. -
  42082. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42083. - int[] shape = metadataExtractor.getInputTensorShape(1);
  42084. - assertThat(shape).isEmpty();
  42085. - }
  42086. -
  42087. - @Test
  42088. - public void getInputTensorType_emptyTensor() throws Exception {
  42089. - // Creates a model flatbuffer with metadata.
  42090. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42091. -
  42092. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42093. - byte type = metadataExtractor.getInputTensorType(1);
  42094. - assertThat(type).isEqualTo(TensorType.FLOAT32);
  42095. - }
  42096. -
  42097. - @Test
  42098. - public void getOutputTensorShape_validTensor() throws Exception {
  42099. - // Creates a model flatbuffer with metadata.
  42100. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42101. -
  42102. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42103. - int[] shape = metadataExtractor.getOutputTensorShape(0);
  42104. - assertArrayEquals(validShape, shape);
  42105. - }
  42106. -
  42107. - @Test
  42108. - public void getOutputTensorShape_emptyTensor() throws Exception {
  42109. - // Creates a model flatbuffer with metadata.
  42110. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42111. -
  42112. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42113. - int[] shape = metadataExtractor.getOutputTensorShape(1);
  42114. - assertThat(shape).isEmpty();
  42115. - }
  42116. -
  42117. - @Test
  42118. - public void getOutputTensorType_emptyTensor() throws Exception {
  42119. - // Creates a model flatbuffer with metadata.
  42120. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42121. -
  42122. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42123. - byte type = metadataExtractor.getOutputTensorType(1);
  42124. - assertThat(type).isEqualTo(TensorType.FLOAT32);
  42125. - }
  42126. -
  42127. - @Test
  42128. - public void getInputTensorShape_indexGreaterThanTensorNumber_throwsException()
  42129. - throws Exception {
  42130. - // Creates a model flatbuffer with metadata.
  42131. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42132. -
  42133. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42134. - IllegalArgumentException exception =
  42135. - assertThrows(
  42136. - IllegalArgumentException.class, () -> metadataExtractor.getInputTensorShape(3));
  42137. - assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
  42138. - }
  42139. -
  42140. - @Test
  42141. - public void getInputTensorShape_negtiveIndex_throwsException() throws Exception {
  42142. - // Creates a model flatbuffer with metadata.
  42143. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42144. -
  42145. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42146. - IllegalArgumentException exception =
  42147. - assertThrows(
  42148. - IllegalArgumentException.class, () -> metadataExtractor.getInputTensorShape(-1));
  42149. - assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
  42150. - }
  42151. -
  42152. - @Test
  42153. - public void getOutputTensorShape_indexGreaterThanTensorNumber_throwsException()
  42154. - throws Exception {
  42155. - // Creates a model flatbuffer with metadata.
  42156. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42157. -
  42158. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42159. - IllegalArgumentException exception =
  42160. - assertThrows(
  42161. - IllegalArgumentException.class, () -> metadataExtractor.getOutputTensorShape(3));
  42162. - assertThat(exception).hasMessageThat().contains("The outputIndex specified is invalid.");
  42163. - }
  42164. -
  42165. - @Test
  42166. - public void getOutputTensorShape_negtiveIndex_throwsException() throws Exception {
  42167. - // Creates a model flatbuffer with metadata.
  42168. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42169. -
  42170. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42171. - IllegalArgumentException exception =
  42172. - assertThrows(
  42173. - IllegalArgumentException.class, () -> metadataExtractor.getOutputTensorShape(-1));
  42174. - assertThat(exception).hasMessageThat().contains("The outputIndex specified is invalid.");
  42175. - }
  42176. -
  42177. - @Test
  42178. - public void getModelMetadata_modelWithMetadata() throws Exception {
  42179. - // Creates a model flatbuffer with metadata.
  42180. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42181. -
  42182. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42183. - ModelMetadata modelMetadata = metadataExtractor.getModelMetadata();
  42184. - assertThat(modelMetadata.name()).isEqualTo(MODEL_NAME);
  42185. - }
  42186. -
  42187. - @Test
  42188. - public void getModelMetadata_modelWithoutMetadata_throwsException() throws Exception {
  42189. - // Creates a model flatbuffer without metadata.
  42190. - ByteBuffer modelWithoutMetadata = createModelByteBuffer(/*metadataBuffer=*/ null, DATA_TYPE);
  42191. -
  42192. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
  42193. -
  42194. - IllegalStateException exception =
  42195. - assertThrows(IllegalStateException.class, () -> metadataExtractor.getModelMetadata());
  42196. - assertThat(exception)
  42197. - .hasMessageThat()
  42198. - .contains("This model does not contain model metadata.");
  42199. - }
  42200. -
  42201. - @Test
  42202. - public void metadataExtractor_modelWithEmptySubgraphMetadata_throwsException() {
  42203. - // Creates a metadata FlatBuffer without empty subgraph metadata.
  42204. - FlatBufferBuilder builder = new FlatBufferBuilder();
  42205. - SubGraphMetadata.startSubGraphMetadata(builder);
  42206. - int subgraph1Metadata = SubGraphMetadata.endSubGraphMetadata(builder);
  42207. - int subgraphsMetadata =
  42208. - ModelMetadata.createSubgraphMetadataVector(builder, new int[] {subgraph1Metadata});
  42209. -
  42210. - ModelMetadata.startModelMetadata(builder);
  42211. - ModelMetadata.addSubgraphMetadata(builder, subgraphsMetadata);
  42212. - int modelMetadata = ModelMetadata.endModelMetadata(builder);
  42213. - builder.finish(modelMetadata, TFLITE_METADATA_IDENTIFIER);
  42214. - ByteBuffer emptyMetadata = builder.dataBuffer();
  42215. - ByteBuffer modelWithEmptyMetadata = createModelByteBuffer(emptyMetadata, DATA_TYPE);
  42216. -
  42217. - IllegalArgumentException exception =
  42218. - assertThrows(
  42219. - IllegalArgumentException.class, () -> new MetadataExtractor(modelWithEmptyMetadata));
  42220. - assertThat(exception)
  42221. - .hasMessageThat()
  42222. - .isEqualTo(
  42223. - "The number of input tensors in the model is 3. The number of input tensors that"
  42224. - + " recorded in the metadata is 0. These two values does not match.");
  42225. - }
  42226. -
  42227. - @Test
  42228. - public void metadataExtractor_modelWithEmptyMetadata_throwsException() {
  42229. - // Creates a empty metadata FlatBuffer.
  42230. - FlatBufferBuilder builder = new FlatBufferBuilder();
  42231. - ModelMetadata.startModelMetadata(builder);
  42232. - int modelMetadata = ModelMetadata.endModelMetadata(builder);
  42233. - builder.finish(modelMetadata, TFLITE_METADATA_IDENTIFIER);
  42234. -
  42235. - ByteBuffer emptyMetadata = builder.dataBuffer();
  42236. - ByteBuffer modelWithEmptyMetadata = createModelByteBuffer(emptyMetadata, DATA_TYPE);
  42237. -
  42238. - IllegalArgumentException exception =
  42239. - assertThrows(
  42240. - IllegalArgumentException.class, () -> new MetadataExtractor(modelWithEmptyMetadata));
  42241. - assertThat(exception)
  42242. - .hasMessageThat()
  42243. - .contains("The metadata flatbuffer does not contain any subgraph metadata.");
  42244. - }
  42245. -
  42246. - @Test
  42247. - public void metadataExtractor_modelWithNoMetadata_throwsException() throws Exception {
  42248. - // Creates a model flatbuffer without metadata.
  42249. - ByteBuffer modelWithoutMetadata = createModelByteBuffer(/*metadataBuffer=*/ null, DATA_TYPE);
  42250. -
  42251. - // It is allowed to create a model without metadata, but invoking methods that reads metadata
  42252. - // is not allowed.
  42253. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
  42254. -
  42255. - IllegalStateException exception =
  42256. - assertThrows(
  42257. - IllegalStateException.class, () -> metadataExtractor.getInputTensorMetadata(0));
  42258. - assertThat(exception)
  42259. - .hasMessageThat()
  42260. - .contains("This model does not contain model metadata.");
  42261. - }
  42262. -
  42263. - @Test
  42264. - public void metadataExtractor_modelWithIrrelevantMetadata_throwsException() throws Exception {
  42265. - // Creates a model with irrelevant metadata.
  42266. - FlatBufferBuilder builder = new FlatBufferBuilder();
  42267. - SubGraph.startSubGraph(builder);
  42268. - int subgraph = SubGraph.endSubGraph(builder);
  42269. -
  42270. - int metadataName = builder.createString("Irrelevant metadata");
  42271. - Metadata.startMetadata(builder);
  42272. - Metadata.addName(builder, metadataName);
  42273. - int metadata = Metadata.endMetadata(builder);
  42274. - int metadataArray = Model.createMetadataVector(builder, new int[] {metadata});
  42275. -
  42276. - // Creates Model.
  42277. - int[] subgraphs = new int[1];
  42278. - subgraphs[0] = subgraph;
  42279. - int modelSubgraphs = Model.createSubgraphsVector(builder, subgraphs);
  42280. - Model.startModel(builder);
  42281. - Model.addSubgraphs(builder, modelSubgraphs);
  42282. - Model.addMetadata(builder, metadataArray);
  42283. - int model = Model.endModel(builder);
  42284. - builder.finish(model, TFLITE_MODEL_IDENTIFIER);
  42285. - ByteBuffer modelBuffer = builder.dataBuffer();
  42286. -
  42287. - // It is allowed to create a model without metadata, but invoking methods that reads metadata
  42288. - // is not allowed.
  42289. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelBuffer);
  42290. -
  42291. - IllegalStateException exception =
  42292. - assertThrows(
  42293. - IllegalStateException.class, () -> metadataExtractor.getInputTensorMetadata(0));
  42294. - assertThat(exception)
  42295. - .hasMessageThat()
  42296. - .contains("This model does not contain model metadata.");
  42297. - }
  42298. -
  42299. - @Test
  42300. - public void getInputTensorMetadata_validTensor() throws Exception {
  42301. - // Creates a model flatbuffer with metadata.
  42302. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42303. -
  42304. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42305. - TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(0);
  42306. - assertThat(inputMetadata.content().contentPropertiesType())
  42307. - .isEqualTo(CONTENT_PROPERTIES_TYPE);
  42308. - }
  42309. -
  42310. - @Test
  42311. - public void getInputTensorMetadata_emptyTensor() throws Exception {
  42312. - // Creates a model flatbuffer with metadata.
  42313. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42314. -
  42315. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42316. - TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(1);
  42317. - assertThat(inputMetadata.content()).isNull();
  42318. - }
  42319. -
  42320. - @Test
  42321. - public void getInputTensorMetadata_invalidTensor() throws Exception {
  42322. - // Creates a model flatbuffer with metadata.
  42323. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42324. -
  42325. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42326. - TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(2);
  42327. - assertThat(inputMetadata.content().contentPropertiesType())
  42328. - .isEqualTo(CONTENT_PROPERTIES_TYPE);
  42329. - }
  42330. -
  42331. - @Test
  42332. - public void getOutputTensorMetadata_validTensor() throws Exception {
  42333. - // Creates a model flatbuffer with metadata.
  42334. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42335. -
  42336. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42337. - TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(0);
  42338. - assertThat(outputMetadata.content().contentPropertiesType())
  42339. - .isEqualTo(CONTENT_PROPERTIES_TYPE);
  42340. - }
  42341. -
  42342. - @Test
  42343. - public void getOutputTensorMetadata_emptyTensor() throws Exception {
  42344. - // Creates a model flatbuffer with metadata.
  42345. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42346. -
  42347. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42348. - TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(1);
  42349. - assertThat(outputMetadata.content()).isNull();
  42350. - }
  42351. -
  42352. - @Test
  42353. - public void getOutputTensorMetadata_invalidTensor() throws Exception {
  42354. - // Creates a model flatbuffer with metadata.
  42355. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42356. -
  42357. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42358. - TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(2);
  42359. - assertThat(outputMetadata.content().contentPropertiesType())
  42360. - .isEqualTo(CONTENT_PROPERTIES_TYPE);
  42361. - }
  42362. -
  42363. - @Test
  42364. - public void getInputTensorMetadata_indexGreaterThanTensorNumber_throwsException()
  42365. - throws Exception {
  42366. - // Creates a model flatbuffer with metadata.
  42367. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42368. -
  42369. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42370. - IllegalArgumentException exception =
  42371. - assertThrows(
  42372. - IllegalArgumentException.class, () -> metadataExtractor.getInputTensorMetadata(3));
  42373. - assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
  42374. - }
  42375. -
  42376. - @Test
  42377. - public void getInputTensorMetadata_negtiveIndex_throwsException() throws Exception {
  42378. - // Creates a model flatbuffer with metadata.
  42379. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42380. -
  42381. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42382. - IllegalArgumentException exception =
  42383. - assertThrows(
  42384. - IllegalArgumentException.class, () -> metadataExtractor.getInputTensorMetadata(-1));
  42385. - assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
  42386. - }
  42387. -
  42388. - @Test
  42389. - public void getOutputTensorMetadata_indexGreaterThanTensorNumber_throwsException()
  42390. - throws Exception {
  42391. - // Creates a model flatbuffer with metadata.
  42392. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42393. -
  42394. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42395. - IllegalArgumentException exception =
  42396. - assertThrows(
  42397. - IllegalArgumentException.class, () -> metadataExtractor.getOutputTensorMetadata(3));
  42398. - assertThat(exception).hasMessageThat().contains("The outputIndex specified is invalid.");
  42399. - }
  42400. -
  42401. - @Test
  42402. - public void getOutputTensorMetadata_negtiveIndex_throwsException() throws Exception {
  42403. - // Creates a model flatbuffer with metadata.
  42404. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42405. -
  42406. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42407. - IllegalArgumentException exception =
  42408. - assertThrows(
  42409. - IllegalArgumentException.class, () -> metadataExtractor.getOutputTensorMetadata(-1));
  42410. - assertThat(exception).hasMessageThat().contains("The outputIndex specified is invalid.");
  42411. - }
  42412. -
  42413. - @Test
  42414. - public void getInputTensorQuantizationParams_validScaleAndZeroPoint() throws Exception {
  42415. - // Creates a model flatbuffer with metadata.
  42416. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42417. -
  42418. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42419. - QuantizationParams quantizationParams = metadataExtractor.getInputTensorQuantizationParams(0);
  42420. - assertThat(quantizationParams.getScale()).isEqualTo(VALID_SCALE);
  42421. - assertThat(quantizationParams.getZeroPoint()).isEqualTo(VALID_ZERO_POINT);
  42422. - }
  42423. -
  42424. - @Test
  42425. - public void getInputTensorQuantizationParams_emptyTensor() throws Exception {
  42426. - // Creates a model flatbuffer with metadata.
  42427. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42428. -
  42429. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42430. - QuantizationParams quantizationParams = metadataExtractor.getInputTensorQuantizationParams(1);
  42431. - // Scale and zero point are expected to be 1.0f and 0, respectively as default.
  42432. - assertThat(quantizationParams.getScale()).isEqualTo(DEFAULT_SCALE);
  42433. - assertThat(quantizationParams.getZeroPoint()).isEqualTo(DEFAULT_ZERO_POINT);
  42434. - }
  42435. -
  42436. - @Test
  42437. - public void getInputTensorQuantizationParams_invalidScale() throws Exception {
  42438. - // Creates a model flatbuffer with metadata.
  42439. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42440. -
  42441. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42442. - IllegalArgumentException exception =
  42443. - assertThrows(
  42444. - IllegalArgumentException.class,
  42445. - () -> metadataExtractor.getInputTensorQuantizationParams(2));
  42446. - assertThat(exception)
  42447. - .hasMessageThat()
  42448. - .contains("Input and output tensors do not support per-channel quantization.");
  42449. - }
  42450. -
  42451. - @Test
  42452. - public void getOutputTensorQuantizationParams_validScaleAndZeroPoint() throws Exception {
  42453. - // Creates a model flatbuffer with metadata.
  42454. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42455. -
  42456. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42457. - QuantizationParams quantizationParams =
  42458. - metadataExtractor.getOutputTensorQuantizationParams(0);
  42459. - assertThat(quantizationParams.getScale()).isEqualTo(VALID_SCALE);
  42460. - assertThat(quantizationParams.getZeroPoint()).isEqualTo(VALID_ZERO_POINT);
  42461. - }
  42462. -
  42463. - @Test
  42464. - public void getOutputTensorQuantizationParams_emptyTensor() throws Exception {
  42465. - // Creates a model flatbuffer with metadata.
  42466. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42467. -
  42468. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42469. - QuantizationParams quantizationParams =
  42470. - metadataExtractor.getOutputTensorQuantizationParams(1);
  42471. - // Scale and zero point are expected to be 1.0f and 0, respectively as default.
  42472. - assertThat(quantizationParams.getScale()).isEqualTo(DEFAULT_SCALE);
  42473. - assertThat(quantizationParams.getZeroPoint()).isEqualTo(DEFAULT_ZERO_POINT);
  42474. - }
  42475. -
  42476. - @Test
  42477. - public void getOutputTensorQuantizationParams_invalidScale() throws Exception {
  42478. - // Creates a model flatbuffer with metadata.
  42479. - ByteBuffer modelWithMetadata = createModelByteBuffer();
  42480. -
  42481. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42482. - IllegalArgumentException exception =
  42483. - assertThrows(
  42484. - IllegalArgumentException.class,
  42485. - () -> metadataExtractor.getOutputTensorQuantizationParams(2));
  42486. - assertThat(exception)
  42487. - .hasMessageThat()
  42488. - .contains("Input and output tensors do not support per-channel quantization.");
  42489. - }
  42490. -
  42491. - @Test
  42492. - public void isMinimumParserVersionSatisfied_olderVersion() throws Exception {
  42493. - // A version older than the current one. The version starts from 1.0.0, thus 0.10.0 will
  42494. - // precede any furture versions.
  42495. - String minVersion = "0.10";
  42496. - // Creates a metadata using the above version.
  42497. - ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
  42498. - ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
  42499. -
  42500. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42501. -
  42502. - assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
  42503. - }
  42504. -
  42505. - @Test
  42506. - public void isMinimumParserVersionSatisfied_sameVersionSamelength() throws Exception {
  42507. - // A version the same as the current one.
  42508. - String minVersion = MetadataParser.VERSION;
  42509. - // Creates a metadata using the above version.
  42510. - ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
  42511. - ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
  42512. -
  42513. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42514. -
  42515. - assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
  42516. - }
  42517. -
  42518. - @Test
  42519. - public void isMinimumParserVersionSatisfied_sameVersionLongerlength() throws Exception {
  42520. - // A version the same as the current one, but with longer length.
  42521. - String minVersion = MetadataParser.VERSION + ".0";
  42522. - // Creates a metadata using the above version.
  42523. - ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
  42524. - ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
  42525. -
  42526. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42527. -
  42528. - assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
  42529. - }
  42530. -
  42531. - @Test
  42532. - public void isMinimumParserVersionSatisfied_emptyVersion() throws Exception {
  42533. - // An empty version, which can be generated before the first versioned release.
  42534. - String minVersion = null;
  42535. - // Creates a metadata using the above version.
  42536. - ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
  42537. - ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
  42538. -
  42539. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42540. -
  42541. - assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
  42542. - }
  42543. -
  42544. - @Test
  42545. - public void isMinimumParserVersionSatisfied_newerVersion() throws Exception {
  42546. - // Creates a version newer than the current one by appending "1" to the end of the current
  42547. - // version for testing purposes. For example, 1.0.0 becomes 1.0.01.
  42548. - String minVersion = MetadataParser.VERSION + "1";
  42549. - // Creates a metadata using the above version.
  42550. - ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
  42551. - ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
  42552. -
  42553. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42554. -
  42555. - assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isFalse();
  42556. - }
  42557. -
  42558. - @Test
  42559. - public void isMinimumParserVersionSatisfied_newerVersionLongerLength() throws Exception {
  42560. - // Creates a version newer than the current one by appending ".1" to the end of the current
  42561. - // version for testing purposes. For example, 1.0.0 becomes 1.0.0.1.
  42562. - String minVersion = MetadataParser.VERSION + ".1";
  42563. - // Creates a metadata using the above version.
  42564. - ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
  42565. - ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
  42566. -
  42567. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42568. -
  42569. - assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isFalse();
  42570. - }
  42571. - }
  42572. -
  42573. - /** Parameterized tests for the input tensor data type. */
  42574. - @RunWith(ParameterizedRobolectricTestRunner.class)
  42575. - public static final class InputTensorType extends MetadataExtractorTest {
  42576. - /** The tensor type that used to create the model buffer. */
  42577. - @Parameter(0)
  42578. - public byte tensorType;
  42579. -
  42580. - /** A list of TensorType that is used in the test. */
  42581. - @Parameters
  42582. - public static Collection<Object[]> data() {
  42583. - return Arrays.asList(
  42584. - new Object[][] {
  42585. - {TensorType.FLOAT32}, {TensorType.INT32},
  42586. - {TensorType.UINT8}, {TensorType.INT64},
  42587. - {TensorType.STRING}
  42588. - });
  42589. - }
  42590. -
  42591. - @Test
  42592. - public void getInputTensorType_validTensor() throws Exception {
  42593. - ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, null);
  42594. - ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, tensorType);
  42595. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42596. - byte type = metadataExtractor.getInputTensorType(0);
  42597. - assertThat(type).isEqualTo(tensorType);
  42598. - }
  42599. -
  42600. - @Test
  42601. - public void getOutputTensorType_validTensor() throws Exception {
  42602. - ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, null);
  42603. - ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, tensorType);
  42604. - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42605. - byte type = metadataExtractor.getOutputTensorType(0);
  42606. - assertThat(type).isEqualTo(tensorType);
  42607. - }
  42608. - }
  42609. -
  42610. - /**
  42611. - * Creates an example metadata flatbuffer, which contains one subgraph with three inputs and three
  42612. - * outputs.
  42613. - */
  42614. - private static ByteBuffer createMetadataByteBuffer(
  42615. - String identifier, @Nullable String minVersionStr) {
  42616. - FlatBufferBuilder builder = new FlatBufferBuilder();
  42617. -
  42618. - Content.startContent(builder);
  42619. - Content.addContentPropertiesType(builder, CONTENT_PROPERTIES_TYPE);
  42620. - int content = Content.endContent(builder);
  42621. -
  42622. - TensorMetadata.startTensorMetadata(builder);
  42623. - TensorMetadata.addContent(builder, content);
  42624. - int metadataForValidTensor = TensorMetadata.endTensorMetadata(builder);
  42625. -
  42626. - TensorMetadata.startTensorMetadata(builder);
  42627. - int metadataForEmptyTensor = TensorMetadata.endTensorMetadata(builder);
  42628. -
  42629. - TensorMetadata.startTensorMetadata(builder);
  42630. - TensorMetadata.addContent(builder, content);
  42631. - int metadataForInvalidTensor = TensorMetadata.endTensorMetadata(builder);
  42632. -
  42633. - int[] tensorMetadataArray =
  42634. - new int[] {metadataForValidTensor, metadataForEmptyTensor, metadataForInvalidTensor};
  42635. - int inputTensorMetadata =
  42636. - SubGraphMetadata.createInputTensorMetadataVector(builder, tensorMetadataArray);
  42637. - int outputTensorMetadata =
  42638. - SubGraphMetadata.createOutputTensorMetadataVector(builder, tensorMetadataArray);
  42639. -
  42640. - SubGraphMetadata.startSubGraphMetadata(builder);
  42641. - SubGraphMetadata.addInputTensorMetadata(builder, inputTensorMetadata);
  42642. - SubGraphMetadata.addOutputTensorMetadata(builder, outputTensorMetadata);
  42643. - int subgraph1Metadata = SubGraphMetadata.endSubGraphMetadata(builder);
  42644. -
  42645. - int[] subgraphMetadataArray = new int[] {subgraph1Metadata};
  42646. - int subgraphsMetadata =
  42647. - ModelMetadata.createSubgraphMetadataVector(builder, subgraphMetadataArray);
  42648. -
  42649. - int modelName = builder.createString(MODEL_NAME);
  42650. - if (minVersionStr != null) {
  42651. - int minVersion = builder.createString(minVersionStr);
  42652. - ModelMetadata.startModelMetadata(builder);
  42653. - ModelMetadata.addMinParserVersion(builder, minVersion);
  42654. - } else {
  42655. - // If minVersionStr is null, skip generating the field in the metadata.
  42656. - ModelMetadata.startModelMetadata(builder);
  42657. - }
  42658. - ModelMetadata.addName(builder, modelName);
  42659. - ModelMetadata.addSubgraphMetadata(builder, subgraphsMetadata);
  42660. - int modelMetadata = ModelMetadata.endModelMetadata(builder);
  42661. -
  42662. - builder.finish(modelMetadata, identifier);
  42663. - return builder.dataBuffer();
  42664. - }
  42665. -
  42666. - private static int createQuantizationParameters(
  42667. - FlatBufferBuilder builder, float[] scale, long[] zeroPoint) {
  42668. - int inputScale = QuantizationParameters.createScaleVector(builder, scale);
  42669. - int inputZeroPoint = QuantizationParameters.createZeroPointVector(builder, zeroPoint);
  42670. - QuantizationParameters.startQuantizationParameters(builder);
  42671. - QuantizationParameters.addScale(builder, inputScale);
  42672. - QuantizationParameters.addZeroPoint(builder, inputZeroPoint);
  42673. - return QuantizationParameters.endQuantizationParameters(builder);
  42674. - }
  42675. -
  42676. - private static int createTensor(
  42677. - FlatBufferBuilder builder, int[] inputShape, byte inputType, int inputQuantization) {
  42678. - int inputShapeVector1 = Tensor.createShapeVector(builder, inputShape);
  42679. - Tensor.startTensor(builder);
  42680. - Tensor.addShape(builder, inputShapeVector1);
  42681. - Tensor.addType(builder, inputType);
  42682. - Tensor.addQuantization(builder, inputQuantization);
  42683. - return Tensor.endTensor(builder);
  42684. - }
  42685. -
  42686. - /**
  42687. - * Creates an example model flatbuffer, which contains one subgraph with three inputs and three
  42688. - * output.
  42689. - */
  42690. - private static ByteBuffer createModelByteBuffer(ByteBuffer metadataBuffer, byte dataType) {
  42691. - FlatBufferBuilder builder = new FlatBufferBuilder();
  42692. -
  42693. - // Creates a valid set of quantization parameters.
  42694. - int validQuantization =
  42695. - createQuantizationParameters(
  42696. - builder, new float[] {VALID_SCALE}, new long[] {VALID_ZERO_POINT});
  42697. -
  42698. - // Creates an invalid set of quantization parameters.
  42699. - int inValidQuantization = createQuantizationParameters(builder, invalidScale, invalidZeroPoint);
  42700. -
  42701. - // Creates an input Tensor with valid quantization parameters.
  42702. - int validTensor = createTensor(builder, validShape, dataType, validQuantization);
  42703. -
  42704. - // Creates an empty input Tensor.
  42705. - Tensor.startTensor(builder);
  42706. - int emptyTensor = Tensor.endTensor(builder);
  42707. -
  42708. - // Creates an input Tensor with invalid quantization parameters.
  42709. - int invalidTensor = createTensor(builder, validShape, dataType, inValidQuantization);
  42710. -
  42711. - // Creates the SubGraph.
  42712. - int[] tensors = new int[6];
  42713. - tensors[0] = validTensor;
  42714. - tensors[1] = emptyTensor;
  42715. - tensors[2] = invalidTensor;
  42716. - tensors[3] = validTensor;
  42717. - tensors[4] = emptyTensor;
  42718. - tensors[5] = invalidTensor;
  42719. - int subgraphTensors = SubGraph.createTensorsVector(builder, tensors);
  42720. -
  42721. - int subgraphInputs = SubGraph.createInputsVector(builder, new int[] {0, 1, 2});
  42722. - int subgraphOutputs = SubGraph.createOutputsVector(builder, new int[] {3, 4, 5});
  42723. -
  42724. - SubGraph.startSubGraph(builder);
  42725. - SubGraph.addTensors(builder, subgraphTensors);
  42726. - SubGraph.addInputs(builder, subgraphInputs);
  42727. - SubGraph.addOutputs(builder, subgraphOutputs);
  42728. - int subgraph = SubGraph.endSubGraph(builder);
  42729. -
  42730. - // Creates the Model.
  42731. - int[] subgraphs = new int[1];
  42732. - subgraphs[0] = subgraph;
  42733. - int modelSubgraphs = Model.createSubgraphsVector(builder, subgraphs);
  42734. -
  42735. - // Inserts metadataBuffer into the model if it's not null.
  42736. - int modelBuffers = EMPTY_FLATBUFFER_VECTOR;
  42737. - int metadataArray = EMPTY_FLATBUFFER_VECTOR;
  42738. - if (metadataBuffer != null) {
  42739. - int data = Buffer.createDataVector(builder, metadataBuffer);
  42740. - Buffer.startBuffer(builder);
  42741. - Buffer.addData(builder, data);
  42742. - int buffer = Buffer.endBuffer(builder);
  42743. - modelBuffers = Model.createBuffersVector(builder, new int[] {buffer});
  42744. -
  42745. - int metadataName = builder.createString(ModelInfo.METADATA_FIELD_NAME);
  42746. - Metadata.startMetadata(builder);
  42747. - Metadata.addName(builder, metadataName);
  42748. - Metadata.addBuffer(builder, 0);
  42749. - int metadata = Metadata.endMetadata(builder);
  42750. - metadataArray = Model.createMetadataVector(builder, new int[] {metadata});
  42751. - }
  42752. -
  42753. - Model.startModel(builder);
  42754. - Model.addSubgraphs(builder, modelSubgraphs);
  42755. - if (modelBuffers != EMPTY_FLATBUFFER_VECTOR && metadataArray != EMPTY_FLATBUFFER_VECTOR) {
  42756. - Model.addBuffers(builder, modelBuffers);
  42757. - Model.addMetadata(builder, metadataArray);
  42758. + private static final int[] validShape = new int[] {4, 10, 10, 3};
  42759. + private static final byte DATA_TYPE = TensorType.UINT8;
  42760. + private static final byte CONTENT_PROPERTIES_TYPE = ContentProperties.ImageProperties;
  42761. + private static final float VALID_SCALE = 3.3f;
  42762. + private static final long VALID_ZERO_POINT = 2;
  42763. + private static final float DEFAULT_SCALE = 0.0f;
  42764. + private static final long DEFAULT_ZERO_POINT = 0;
  42765. + private static final String MODEL_NAME = "model.tflite";
  42766. + // Scale and zero point should both be a single value, not an array.
  42767. + private static final float[] invalidScale = new float[] {0.0f, 1.2f};
  42768. + private static final long[] invalidZeroPoint = new long[] {1, 2};
  42769. + private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite";
  42770. + // labels.txt is packed in mobilenet_v1_1.0_224_quant.tflite as an associated file.
  42771. + private static final String VALID_LABEL_FILE_NAME = "labels.txt";
  42772. + // invalid.txt is not packed in mobilenet_v1_1.0_224_quant.tflite.
  42773. + private static final String INVALID_LABEL_FILE_NAME = "invalid.txt";
  42774. + private static final int EMPTY_FLATBUFFER_VECTOR = -1;
  42775. + private static final String TFLITE_MODEL_IDENTIFIER = "TFL3";
  42776. + private static final String TFLITE_METADATA_IDENTIFIER = "M001";
  42777. +
  42778. + /** General tests of MetadataExtractor. */
  42779. + @RunWith(RobolectricTestRunner.class)
  42780. + public static final class General extends MetadataExtractorTest {
  42781. + @Test
  42782. + public void hasMetadata_modelWithMetadata() throws Exception {
  42783. + // Creates a model flatbuffer with metadata.
  42784. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  42785. +
  42786. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42787. + assertThat(metadataExtractor.hasMetadata()).isTrue();
  42788. + }
  42789. +
  42790. + @Test
  42791. + public void hasMetadata_modelWithoutMetadata() throws Exception {
  42792. + // Creates a model flatbuffer without metadata.
  42793. + ByteBuffer modelWithoutMetadata =
  42794. + createModelByteBuffer(/*metadataBuffer=*/null, DATA_TYPE);
  42795. +
  42796. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
  42797. + assertThat(metadataExtractor.hasMetadata()).isFalse();
  42798. + }
  42799. +
  42800. + @Ignore
  42801. + @Test
  42802. + public void getAssociatedFile_validAssociateFile() throws Exception {
  42803. + ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
  42804. + MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
  42805. + InputStream associateFileStream =
  42806. + mobileNetMetadataExtractor.getAssociatedFile(VALID_LABEL_FILE_NAME);
  42807. +
  42808. + // Reads the golden file from context.
  42809. + Context context = ApplicationProvider.getApplicationContext();
  42810. + InputStream goldenAssociateFileStream = context.getAssets().open(VALID_LABEL_FILE_NAME);
  42811. + assertThat(IOUtils.contentEquals(goldenAssociateFileStream, associateFileStream))
  42812. + .isTrue();
  42813. + }
  42814. +
  42815. + @Ignore
  42816. + @Test
  42817. + public void getAssociatedFile_invalidAssociateFile() throws Exception {
  42818. + ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
  42819. + MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
  42820. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  42821. + () -> mobileNetMetadataExtractor.getAssociatedFile(INVALID_LABEL_FILE_NAME));
  42822. + assertThat(exception).hasMessageThat().isEqualTo(String.format(
  42823. + "The file, %s, does not exist in the zip file.", INVALID_LABEL_FILE_NAME));
  42824. + }
  42825. +
  42826. + @Ignore
  42827. + @Test
  42828. + public void getAssociatedFile_nullFileName() throws Exception {
  42829. + ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
  42830. + MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
  42831. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  42832. + () -> mobileNetMetadataExtractor.getAssociatedFile(/*fileName=*/null));
  42833. + assertThat(exception).hasMessageThat().contains(
  42834. + "The file, null, does not exist in the zip file.");
  42835. + }
  42836. +
  42837. + @Test
  42838. + public void getAssociatedFile_nonZipModel_throwsException() throws Exception {
  42839. + // Creates a model flatbuffer with metadata.
  42840. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  42841. +
  42842. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42843. + IllegalStateException exception = assertThrows(IllegalStateException.class,
  42844. + () -> metadataExtractor.getAssociatedFile(VALID_LABEL_FILE_NAME));
  42845. + assertThat(exception).hasMessageThat().contains(
  42846. + "This model does not contain associated files, and is not a Zip file.");
  42847. + }
  42848. +
  42849. + @Test
  42850. + public void getAssociatedFileNames_nonZipModel_throwsException() throws Exception {
  42851. + // Creates a model flatbuffer with metadata.
  42852. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  42853. +
  42854. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42855. + IllegalStateException exception = assertThrows(
  42856. + IllegalStateException.class, metadataExtractor::getAssociatedFileNames);
  42857. + assertThat(exception).hasMessageThat().contains(
  42858. + "This model does not contain associated files, and is not a Zip file.");
  42859. + }
  42860. +
  42861. + @Ignore
  42862. + @Test
  42863. + public void getAssociatedFileNames_validFileNames() throws Exception {
  42864. + ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
  42865. + MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
  42866. + Set<String> expectedSet = new HashSet<>();
  42867. + expectedSet.add(VALID_LABEL_FILE_NAME);
  42868. + assertThat(mobileNetMetadataExtractor.getAssociatedFileNames()).isEqualTo(expectedSet);
  42869. + }
  42870. +
  42871. + @Test
  42872. + public void metadataExtractor_loadNullBuffer_throwsException() {
  42873. + ByteBuffer nullBuffer = null;
  42874. + NullPointerException exception = assertThrows(
  42875. + NullPointerException.class, () -> new MetadataExtractor(nullBuffer));
  42876. + assertThat(exception).hasMessageThat().contains("Model flatbuffer cannot be null.");
  42877. + }
  42878. +
  42879. + @Test
  42880. + public void metadataExtractor_loadRandomBuffer_throwsException() {
  42881. + ByteBuffer randomBuffer = createRandomByteBuffer();
  42882. + IllegalArgumentException exception = assertThrows(
  42883. + IllegalArgumentException.class, () -> new MetadataExtractor(randomBuffer));
  42884. + assertThat(exception).hasMessageThat().contains(
  42885. + "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
  42886. + + " flatbuffer.");
  42887. + }
  42888. +
  42889. + @Test
  42890. + public void metadataExtractor_loadModelWithInvalidIdentifier_throwsException() {
  42891. + // Creates a model with an invalid identifier.
  42892. + String invalidIdentifier = "INVI";
  42893. + FlatBufferBuilder builder = new FlatBufferBuilder();
  42894. + Model.startModel(builder);
  42895. + int model = Model.endModel(builder);
  42896. + builder.finish(model, invalidIdentifier);
  42897. + ByteBuffer modelBuffer = builder.dataBuffer();
  42898. +
  42899. + IllegalArgumentException exception = assertThrows(
  42900. + IllegalArgumentException.class, () -> new MetadataExtractor(modelBuffer));
  42901. + assertThat(exception).hasMessageThat().contains(
  42902. + "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
  42903. + + " flatbuffer.");
  42904. + }
  42905. +
  42906. + @Test
  42907. + public void metadataExtractor_loadMetadataWithInvalidIdentifier_throwsException() {
  42908. + // Creates a model with metadata which contains an invalid identifier.
  42909. + String invalidIdentifier = "INVI";
  42910. + ByteBuffer metadata = createMetadataByteBuffer(invalidIdentifier, null);
  42911. + ByteBuffer modelBuffer = createModelByteBuffer(metadata, DATA_TYPE);
  42912. +
  42913. + IllegalArgumentException exception = assertThrows(
  42914. + IllegalArgumentException.class, () -> new MetadataExtractor(modelBuffer));
  42915. + assertThat(exception).hasMessageThat().contains(
  42916. + "The identifier of the metadata is invalid. The buffer may not be a valid TFLite"
  42917. + + " metadata flatbuffer.");
  42918. + }
  42919. +
  42920. + @Test
  42921. + public void getInputTensorCount_validModelFile() throws Exception {
  42922. + // Creates a model flatbuffer with metadata.
  42923. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  42924. +
  42925. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42926. + int count = metadataExtractor.getInputTensorCount();
  42927. + assertThat(count).isEqualTo(3);
  42928. + }
  42929. +
  42930. + @Test
  42931. + public void getOutputTensorCount_validModelFile() throws Exception {
  42932. + // Creates a model flatbuffer with metadata.
  42933. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  42934. +
  42935. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42936. + int count = metadataExtractor.getOutputTensorCount();
  42937. + assertThat(count).isEqualTo(3);
  42938. + }
  42939. +
  42940. + @Test
  42941. + public void getInputTensorShape_validTensorShape() throws Exception {
  42942. + // Creates a model flatbuffer with metadata.
  42943. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  42944. +
  42945. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42946. + int[] shape = metadataExtractor.getInputTensorShape(0);
  42947. + assertArrayEquals(validShape, shape);
  42948. + }
  42949. +
  42950. + @Test
  42951. + public void getInputTensorShape_emptyTensor() throws Exception {
  42952. + // Creates a model flatbuffer with metadata.
  42953. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  42954. +
  42955. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42956. + int[] shape = metadataExtractor.getInputTensorShape(1);
  42957. + assertThat(shape).isEmpty();
  42958. + }
  42959. +
  42960. + @Test
  42961. + public void getInputTensorType_emptyTensor() throws Exception {
  42962. + // Creates a model flatbuffer with metadata.
  42963. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  42964. +
  42965. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42966. + byte type = metadataExtractor.getInputTensorType(1);
  42967. + assertThat(type).isEqualTo(TensorType.FLOAT32);
  42968. + }
  42969. +
  42970. + @Test
  42971. + public void getOutputTensorShape_validTensor() throws Exception {
  42972. + // Creates a model flatbuffer with metadata.
  42973. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  42974. +
  42975. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42976. + int[] shape = metadataExtractor.getOutputTensorShape(0);
  42977. + assertArrayEquals(validShape, shape);
  42978. + }
  42979. +
  42980. + @Test
  42981. + public void getOutputTensorShape_emptyTensor() throws Exception {
  42982. + // Creates a model flatbuffer with metadata.
  42983. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  42984. +
  42985. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42986. + int[] shape = metadataExtractor.getOutputTensorShape(1);
  42987. + assertThat(shape).isEmpty();
  42988. + }
  42989. +
  42990. + @Test
  42991. + public void getOutputTensorType_emptyTensor() throws Exception {
  42992. + // Creates a model flatbuffer with metadata.
  42993. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  42994. +
  42995. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  42996. + byte type = metadataExtractor.getOutputTensorType(1);
  42997. + assertThat(type).isEqualTo(TensorType.FLOAT32);
  42998. + }
  42999. +
  43000. + @Test
  43001. + public void getInputTensorShape_indexGreaterThanTensorNumber_throwsException()
  43002. + throws Exception {
  43003. + // Creates a model flatbuffer with metadata.
  43004. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43005. +
  43006. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43007. + IllegalArgumentException exception = assertThrows(
  43008. + IllegalArgumentException.class, () -> metadataExtractor.getInputTensorShape(3));
  43009. + assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
  43010. + }
  43011. +
  43012. + @Test
  43013. + public void getInputTensorShape_negtiveIndex_throwsException() throws Exception {
  43014. + // Creates a model flatbuffer with metadata.
  43015. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43016. +
  43017. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43018. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  43019. + () -> metadataExtractor.getInputTensorShape(-1));
  43020. + assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
  43021. + }
  43022. +
  43023. + @Test
  43024. + public void getOutputTensorShape_indexGreaterThanTensorNumber_throwsException()
  43025. + throws Exception {
  43026. + // Creates a model flatbuffer with metadata.
  43027. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43028. +
  43029. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43030. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  43031. + () -> metadataExtractor.getOutputTensorShape(3));
  43032. + assertThat(exception).hasMessageThat().contains(
  43033. + "The outputIndex specified is invalid.");
  43034. + }
  43035. +
  43036. + @Test
  43037. + public void getOutputTensorShape_negtiveIndex_throwsException() throws Exception {
  43038. + // Creates a model flatbuffer with metadata.
  43039. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43040. +
  43041. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43042. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  43043. + () -> metadataExtractor.getOutputTensorShape(-1));
  43044. + assertThat(exception).hasMessageThat().contains(
  43045. + "The outputIndex specified is invalid.");
  43046. + }
  43047. +
  43048. + @Test
  43049. + public void getModelMetadata_modelWithMetadata() throws Exception {
  43050. + // Creates a model flatbuffer with metadata.
  43051. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43052. +
  43053. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43054. + ModelMetadata modelMetadata = metadataExtractor.getModelMetadata();
  43055. + assertThat(modelMetadata.name()).isEqualTo(MODEL_NAME);
  43056. + }
  43057. +
  43058. + @Test
  43059. + public void getModelMetadata_modelWithoutMetadata_throwsException() throws Exception {
  43060. + // Creates a model flatbuffer without metadata.
  43061. + ByteBuffer modelWithoutMetadata =
  43062. + createModelByteBuffer(/*metadataBuffer=*/null, DATA_TYPE);
  43063. +
  43064. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
  43065. +
  43066. + IllegalStateException exception = assertThrows(
  43067. + IllegalStateException.class, () -> metadataExtractor.getModelMetadata());
  43068. + assertThat(exception).hasMessageThat().contains(
  43069. + "This model does not contain model metadata.");
  43070. + }
  43071. +
  43072. + @Test
  43073. + public void metadataExtractor_modelWithEmptySubgraphMetadata_throwsException() {
  43074. + // Creates a metadata FlatBuffer without empty subgraph metadata.
  43075. + FlatBufferBuilder builder = new FlatBufferBuilder();
  43076. + SubGraphMetadata.startSubGraphMetadata(builder);
  43077. + int subgraph1Metadata = SubGraphMetadata.endSubGraphMetadata(builder);
  43078. + int subgraphsMetadata = ModelMetadata.createSubgraphMetadataVector(
  43079. + builder, new int[] {subgraph1Metadata});
  43080. +
  43081. + ModelMetadata.startModelMetadata(builder);
  43082. + ModelMetadata.addSubgraphMetadata(builder, subgraphsMetadata);
  43083. + int modelMetadata = ModelMetadata.endModelMetadata(builder);
  43084. + builder.finish(modelMetadata, TFLITE_METADATA_IDENTIFIER);
  43085. + ByteBuffer emptyMetadata = builder.dataBuffer();
  43086. + ByteBuffer modelWithEmptyMetadata = createModelByteBuffer(emptyMetadata, DATA_TYPE);
  43087. +
  43088. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  43089. + () -> new MetadataExtractor(modelWithEmptyMetadata));
  43090. + assertThat(exception).hasMessageThat().isEqualTo(
  43091. + "The number of input tensors in the model is 3. The number of input tensors that"
  43092. + + " recorded in the metadata is 0. These two values does not match.");
  43093. + }
  43094. +
  43095. + @Test
  43096. + public void metadataExtractor_modelWithEmptyMetadata_throwsException() {
  43097. + // Creates a empty metadata FlatBuffer.
  43098. + FlatBufferBuilder builder = new FlatBufferBuilder();
  43099. + ModelMetadata.startModelMetadata(builder);
  43100. + int modelMetadata = ModelMetadata.endModelMetadata(builder);
  43101. + builder.finish(modelMetadata, TFLITE_METADATA_IDENTIFIER);
  43102. +
  43103. + ByteBuffer emptyMetadata = builder.dataBuffer();
  43104. + ByteBuffer modelWithEmptyMetadata = createModelByteBuffer(emptyMetadata, DATA_TYPE);
  43105. +
  43106. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  43107. + () -> new MetadataExtractor(modelWithEmptyMetadata));
  43108. + assertThat(exception).hasMessageThat().contains(
  43109. + "The metadata flatbuffer does not contain any subgraph metadata.");
  43110. + }
  43111. +
  43112. + @Test
  43113. + public void metadataExtractor_modelWithNoMetadata_throwsException() throws Exception {
  43114. + // Creates a model flatbuffer without metadata.
  43115. + ByteBuffer modelWithoutMetadata =
  43116. + createModelByteBuffer(/*metadataBuffer=*/null, DATA_TYPE);
  43117. +
  43118. + // It is allowed to create a model without metadata, but invoking methods that reads
  43119. + // metadata is not allowed.
  43120. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
  43121. +
  43122. + IllegalStateException exception = assertThrows(
  43123. + IllegalStateException.class, () -> metadataExtractor.getInputTensorMetadata(0));
  43124. + assertThat(exception).hasMessageThat().contains(
  43125. + "This model does not contain model metadata.");
  43126. + }
  43127. +
  43128. + @Test
  43129. + public void metadataExtractor_modelWithIrrelevantMetadata_throwsException()
  43130. + throws Exception {
  43131. + // Creates a model with irrelevant metadata.
  43132. + FlatBufferBuilder builder = new FlatBufferBuilder();
  43133. + SubGraph.startSubGraph(builder);
  43134. + int subgraph = SubGraph.endSubGraph(builder);
  43135. +
  43136. + int metadataName = builder.createString("Irrelevant metadata");
  43137. + Metadata.startMetadata(builder);
  43138. + Metadata.addName(builder, metadataName);
  43139. + int metadata = Metadata.endMetadata(builder);
  43140. + int metadataArray = Model.createMetadataVector(builder, new int[] {metadata});
  43141. +
  43142. + // Creates Model.
  43143. + int[] subgraphs = new int[1];
  43144. + subgraphs[0] = subgraph;
  43145. + int modelSubgraphs = Model.createSubgraphsVector(builder, subgraphs);
  43146. + Model.startModel(builder);
  43147. + Model.addSubgraphs(builder, modelSubgraphs);
  43148. + Model.addMetadata(builder, metadataArray);
  43149. + int model = Model.endModel(builder);
  43150. + builder.finish(model, TFLITE_MODEL_IDENTIFIER);
  43151. + ByteBuffer modelBuffer = builder.dataBuffer();
  43152. +
  43153. + // It is allowed to create a model without metadata, but invoking methods that reads
  43154. + // metadata is not allowed.
  43155. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelBuffer);
  43156. +
  43157. + IllegalStateException exception = assertThrows(
  43158. + IllegalStateException.class, () -> metadataExtractor.getInputTensorMetadata(0));
  43159. + assertThat(exception).hasMessageThat().contains(
  43160. + "This model does not contain model metadata.");
  43161. + }
  43162. +
  43163. + @Test
  43164. + public void getInputTensorMetadata_validTensor() throws Exception {
  43165. + // Creates a model flatbuffer with metadata.
  43166. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43167. +
  43168. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43169. + TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(0);
  43170. + assertThat(inputMetadata.content().contentPropertiesType())
  43171. + .isEqualTo(CONTENT_PROPERTIES_TYPE);
  43172. + }
  43173. +
  43174. + @Test
  43175. + public void getInputTensorMetadata_emptyTensor() throws Exception {
  43176. + // Creates a model flatbuffer with metadata.
  43177. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43178. +
  43179. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43180. + TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(1);
  43181. + assertThat(inputMetadata.content()).isNull();
  43182. + }
  43183. +
  43184. + @Test
  43185. + public void getInputTensorMetadata_invalidTensor() throws Exception {
  43186. + // Creates a model flatbuffer with metadata.
  43187. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43188. +
  43189. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43190. + TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(2);
  43191. + assertThat(inputMetadata.content().contentPropertiesType())
  43192. + .isEqualTo(CONTENT_PROPERTIES_TYPE);
  43193. + }
  43194. +
  43195. + @Test
  43196. + public void getOutputTensorMetadata_validTensor() throws Exception {
  43197. + // Creates a model flatbuffer with metadata.
  43198. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43199. +
  43200. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43201. + TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(0);
  43202. + assertThat(outputMetadata.content().contentPropertiesType())
  43203. + .isEqualTo(CONTENT_PROPERTIES_TYPE);
  43204. + }
  43205. +
  43206. + @Test
  43207. + public void getOutputTensorMetadata_emptyTensor() throws Exception {
  43208. + // Creates a model flatbuffer with metadata.
  43209. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43210. +
  43211. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43212. + TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(1);
  43213. + assertThat(outputMetadata.content()).isNull();
  43214. + }
  43215. +
  43216. + @Test
  43217. + public void getOutputTensorMetadata_invalidTensor() throws Exception {
  43218. + // Creates a model flatbuffer with metadata.
  43219. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43220. +
  43221. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43222. + TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(2);
  43223. + assertThat(outputMetadata.content().contentPropertiesType())
  43224. + .isEqualTo(CONTENT_PROPERTIES_TYPE);
  43225. + }
  43226. +
  43227. + @Test
  43228. + public void getInputTensorMetadata_indexGreaterThanTensorNumber_throwsException()
  43229. + throws Exception {
  43230. + // Creates a model flatbuffer with metadata.
  43231. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43232. +
  43233. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43234. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  43235. + () -> metadataExtractor.getInputTensorMetadata(3));
  43236. + assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
  43237. + }
  43238. +
  43239. + @Test
  43240. + public void getInputTensorMetadata_negtiveIndex_throwsException() throws Exception {
  43241. + // Creates a model flatbuffer with metadata.
  43242. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43243. +
  43244. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43245. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  43246. + () -> metadataExtractor.getInputTensorMetadata(-1));
  43247. + assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
  43248. + }
  43249. +
  43250. + @Test
  43251. + public void getOutputTensorMetadata_indexGreaterThanTensorNumber_throwsException()
  43252. + throws Exception {
  43253. + // Creates a model flatbuffer with metadata.
  43254. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43255. +
  43256. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43257. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  43258. + () -> metadataExtractor.getOutputTensorMetadata(3));
  43259. + assertThat(exception).hasMessageThat().contains(
  43260. + "The outputIndex specified is invalid.");
  43261. + }
  43262. +
  43263. + @Test
  43264. + public void getOutputTensorMetadata_negtiveIndex_throwsException() throws Exception {
  43265. + // Creates a model flatbuffer with metadata.
  43266. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43267. +
  43268. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43269. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  43270. + () -> metadataExtractor.getOutputTensorMetadata(-1));
  43271. + assertThat(exception).hasMessageThat().contains(
  43272. + "The outputIndex specified is invalid.");
  43273. + }
  43274. +
  43275. + @Test
  43276. + public void getInputTensorQuantizationParams_validScaleAndZeroPoint() throws Exception {
  43277. + // Creates a model flatbuffer with metadata.
  43278. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43279. +
  43280. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43281. + QuantizationParams quantizationParams =
  43282. + metadataExtractor.getInputTensorQuantizationParams(0);
  43283. + assertThat(quantizationParams.getScale()).isEqualTo(VALID_SCALE);
  43284. + assertThat(quantizationParams.getZeroPoint()).isEqualTo(VALID_ZERO_POINT);
  43285. + }
  43286. +
  43287. + @Test
  43288. + public void getInputTensorQuantizationParams_emptyTensor() throws Exception {
  43289. + // Creates a model flatbuffer with metadata.
  43290. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43291. +
  43292. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43293. + QuantizationParams quantizationParams =
  43294. + metadataExtractor.getInputTensorQuantizationParams(1);
  43295. + // Scale and zero point are expected to be 1.0f and 0, respectively as default.
  43296. + assertThat(quantizationParams.getScale()).isEqualTo(DEFAULT_SCALE);
  43297. + assertThat(quantizationParams.getZeroPoint()).isEqualTo(DEFAULT_ZERO_POINT);
  43298. + }
  43299. +
  43300. + @Test
  43301. + public void getInputTensorQuantizationParams_invalidScale() throws Exception {
  43302. + // Creates a model flatbuffer with metadata.
  43303. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43304. +
  43305. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43306. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  43307. + () -> metadataExtractor.getInputTensorQuantizationParams(2));
  43308. + assertThat(exception).hasMessageThat().contains(
  43309. + "Input and output tensors do not support per-channel quantization.");
  43310. + }
  43311. +
  43312. + @Test
  43313. + public void getOutputTensorQuantizationParams_validScaleAndZeroPoint() throws Exception {
  43314. + // Creates a model flatbuffer with metadata.
  43315. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43316. +
  43317. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43318. + QuantizationParams quantizationParams =
  43319. + metadataExtractor.getOutputTensorQuantizationParams(0);
  43320. + assertThat(quantizationParams.getScale()).isEqualTo(VALID_SCALE);
  43321. + assertThat(quantizationParams.getZeroPoint()).isEqualTo(VALID_ZERO_POINT);
  43322. + }
  43323. +
  43324. + @Test
  43325. + public void getOutputTensorQuantizationParams_emptyTensor() throws Exception {
  43326. + // Creates a model flatbuffer with metadata.
  43327. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43328. +
  43329. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43330. + QuantizationParams quantizationParams =
  43331. + metadataExtractor.getOutputTensorQuantizationParams(1);
  43332. + // Scale and zero point are expected to be 1.0f and 0, respectively as default.
  43333. + assertThat(quantizationParams.getScale()).isEqualTo(DEFAULT_SCALE);
  43334. + assertThat(quantizationParams.getZeroPoint()).isEqualTo(DEFAULT_ZERO_POINT);
  43335. + }
  43336. +
  43337. + @Test
  43338. + public void getOutputTensorQuantizationParams_invalidScale() throws Exception {
  43339. + // Creates a model flatbuffer with metadata.
  43340. + ByteBuffer modelWithMetadata = createModelByteBuffer();
  43341. +
  43342. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43343. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  43344. + () -> metadataExtractor.getOutputTensorQuantizationParams(2));
  43345. + assertThat(exception).hasMessageThat().contains(
  43346. + "Input and output tensors do not support per-channel quantization.");
  43347. + }
  43348. +
  43349. + @Test
  43350. + public void isMinimumParserVersionSatisfied_olderVersion() throws Exception {
  43351. + // A version older than the current one. The version starts from 1.0.0, thus 0.10.0 will
  43352. + // precede any furture versions.
  43353. + String minVersion = "0.10";
  43354. + // Creates a metadata using the above version.
  43355. + ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
  43356. + ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
  43357. +
  43358. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43359. +
  43360. + assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
  43361. + }
  43362. +
  43363. + @Test
  43364. + public void isMinimumParserVersionSatisfied_sameVersionSamelength() throws Exception {
  43365. + // A version the same as the current one.
  43366. + String minVersion = MetadataParser.VERSION;
  43367. + // Creates a metadata using the above version.
  43368. + ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
  43369. + ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
  43370. +
  43371. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43372. +
  43373. + assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
  43374. + }
  43375. +
  43376. + @Test
  43377. + public void isMinimumParserVersionSatisfied_sameVersionLongerlength() throws Exception {
  43378. + // A version the same as the current one, but with longer length.
  43379. + String minVersion = MetadataParser.VERSION + ".0";
  43380. + // Creates a metadata using the above version.
  43381. + ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
  43382. + ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
  43383. +
  43384. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43385. +
  43386. + assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
  43387. + }
  43388. +
  43389. + @Test
  43390. + public void isMinimumParserVersionSatisfied_emptyVersion() throws Exception {
  43391. + // An empty version, which can be generated before the first versioned release.
  43392. + String minVersion = null;
  43393. + // Creates a metadata using the above version.
  43394. + ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
  43395. + ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
  43396. +
  43397. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43398. +
  43399. + assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
  43400. + }
  43401. +
  43402. + @Test
  43403. + public void isMinimumParserVersionSatisfied_newerVersion() throws Exception {
  43404. + // Creates a version newer than the current one by appending "1" to the end of the
  43405. + // current version for testing purposes. For example, 1.0.0 becomes 1.0.01.
  43406. + String minVersion = MetadataParser.VERSION + "1";
  43407. + // Creates a metadata using the above version.
  43408. + ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
  43409. + ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
  43410. +
  43411. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43412. +
  43413. + assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isFalse();
  43414. + }
  43415. +
  43416. + @Test
  43417. + public void isMinimumParserVersionSatisfied_newerVersionLongerLength() throws Exception {
  43418. + // Creates a version newer than the current one by appending ".1" to the end of the
  43419. + // current version for testing purposes. For example, 1.0.0 becomes 1.0.0.1.
  43420. + String minVersion = MetadataParser.VERSION + ".1";
  43421. + // Creates a metadata using the above version.
  43422. + ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
  43423. + ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
  43424. +
  43425. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43426. +
  43427. + assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isFalse();
  43428. + }
  43429. + }
  43430. +
  43431. + /** Parameterized tests for the input tensor data type. */
  43432. + @RunWith(ParameterizedRobolectricTestRunner.class)
  43433. + public static final class InputTensorType extends MetadataExtractorTest {
  43434. + /** The tensor type that used to create the model buffer. */
  43435. + @Parameter(0)
  43436. + public byte tensorType;
  43437. +
  43438. + /** A list of TensorType that is used in the test. */
  43439. + @Parameters
  43440. + public static Collection<Object[]> data() {
  43441. + return Arrays.asList(new Object[][] {{TensorType.FLOAT32}, {TensorType.INT32},
  43442. + {TensorType.UINT8}, {TensorType.INT64}, {TensorType.STRING}});
  43443. + }
  43444. +
  43445. + @Test
  43446. + public void getInputTensorType_validTensor() throws Exception {
  43447. + ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, null);
  43448. + ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, tensorType);
  43449. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43450. + byte type = metadataExtractor.getInputTensorType(0);
  43451. + assertThat(type).isEqualTo(tensorType);
  43452. + }
  43453. +
  43454. + @Test
  43455. + public void getOutputTensorType_validTensor() throws Exception {
  43456. + ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, null);
  43457. + ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, tensorType);
  43458. + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
  43459. + byte type = metadataExtractor.getOutputTensorType(0);
  43460. + assertThat(type).isEqualTo(tensorType);
  43461. + }
  43462. + }
  43463. +
  43464. + /**
  43465. + * Creates an example metadata flatbuffer, which contains one subgraph with three inputs and
  43466. + * three outputs.
  43467. + */
  43468. + private static ByteBuffer createMetadataByteBuffer(
  43469. + String identifier, @Nullable String minVersionStr) {
  43470. + FlatBufferBuilder builder = new FlatBufferBuilder();
  43471. +
  43472. + Content.startContent(builder);
  43473. + Content.addContentPropertiesType(builder, CONTENT_PROPERTIES_TYPE);
  43474. + int content = Content.endContent(builder);
  43475. +
  43476. + TensorMetadata.startTensorMetadata(builder);
  43477. + TensorMetadata.addContent(builder, content);
  43478. + int metadataForValidTensor = TensorMetadata.endTensorMetadata(builder);
  43479. +
  43480. + TensorMetadata.startTensorMetadata(builder);
  43481. + int metadataForEmptyTensor = TensorMetadata.endTensorMetadata(builder);
  43482. +
  43483. + TensorMetadata.startTensorMetadata(builder);
  43484. + TensorMetadata.addContent(builder, content);
  43485. + int metadataForInvalidTensor = TensorMetadata.endTensorMetadata(builder);
  43486. +
  43487. + int[] tensorMetadataArray = new int[] {
  43488. + metadataForValidTensor, metadataForEmptyTensor, metadataForInvalidTensor};
  43489. + int inputTensorMetadata =
  43490. + SubGraphMetadata.createInputTensorMetadataVector(builder, tensorMetadataArray);
  43491. + int outputTensorMetadata =
  43492. + SubGraphMetadata.createOutputTensorMetadataVector(builder, tensorMetadataArray);
  43493. +
  43494. + SubGraphMetadata.startSubGraphMetadata(builder);
  43495. + SubGraphMetadata.addInputTensorMetadata(builder, inputTensorMetadata);
  43496. + SubGraphMetadata.addOutputTensorMetadata(builder, outputTensorMetadata);
  43497. + int subgraph1Metadata = SubGraphMetadata.endSubGraphMetadata(builder);
  43498. +
  43499. + int[] subgraphMetadataArray = new int[] {subgraph1Metadata};
  43500. + int subgraphsMetadata =
  43501. + ModelMetadata.createSubgraphMetadataVector(builder, subgraphMetadataArray);
  43502. +
  43503. + int modelName = builder.createString(MODEL_NAME);
  43504. + if (minVersionStr != null) {
  43505. + int minVersion = builder.createString(minVersionStr);
  43506. + ModelMetadata.startModelMetadata(builder);
  43507. + ModelMetadata.addMinParserVersion(builder, minVersion);
  43508. + } else {
  43509. + // If minVersionStr is null, skip generating the field in the metadata.
  43510. + ModelMetadata.startModelMetadata(builder);
  43511. + }
  43512. + ModelMetadata.addName(builder, modelName);
  43513. + ModelMetadata.addSubgraphMetadata(builder, subgraphsMetadata);
  43514. + int modelMetadata = ModelMetadata.endModelMetadata(builder);
  43515. +
  43516. + builder.finish(modelMetadata, identifier);
  43517. + return builder.dataBuffer();
  43518. + }
  43519. +
  43520. + private static int createQuantizationParameters(
  43521. + FlatBufferBuilder builder, float[] scale, long[] zeroPoint) {
  43522. + int inputScale = QuantizationParameters.createScaleVector(builder, scale);
  43523. + int inputZeroPoint = QuantizationParameters.createZeroPointVector(builder, zeroPoint);
  43524. + QuantizationParameters.startQuantizationParameters(builder);
  43525. + QuantizationParameters.addScale(builder, inputScale);
  43526. + QuantizationParameters.addZeroPoint(builder, inputZeroPoint);
  43527. + return QuantizationParameters.endQuantizationParameters(builder);
  43528. + }
  43529. +
  43530. + private static int createTensor(
  43531. + FlatBufferBuilder builder, int[] inputShape, byte inputType, int inputQuantization) {
  43532. + int inputShapeVector1 = Tensor.createShapeVector(builder, inputShape);
  43533. + Tensor.startTensor(builder);
  43534. + Tensor.addShape(builder, inputShapeVector1);
  43535. + Tensor.addType(builder, inputType);
  43536. + Tensor.addQuantization(builder, inputQuantization);
  43537. + return Tensor.endTensor(builder);
  43538. + }
  43539. +
  43540. + /**
  43541. + * Creates an example model flatbuffer, which contains one subgraph with three inputs and three
  43542. + * output.
  43543. + */
  43544. + private static ByteBuffer createModelByteBuffer(ByteBuffer metadataBuffer, byte dataType) {
  43545. + FlatBufferBuilder builder = new FlatBufferBuilder();
  43546. +
  43547. + // Creates a valid set of quantization parameters.
  43548. + int validQuantization = createQuantizationParameters(
  43549. + builder, new float[] {VALID_SCALE}, new long[] {VALID_ZERO_POINT});
  43550. +
  43551. + // Creates an invalid set of quantization parameters.
  43552. + int inValidQuantization =
  43553. + createQuantizationParameters(builder, invalidScale, invalidZeroPoint);
  43554. +
  43555. + // Creates an input Tensor with valid quantization parameters.
  43556. + int validTensor = createTensor(builder, validShape, dataType, validQuantization);
  43557. +
  43558. + // Creates an empty input Tensor.
  43559. + Tensor.startTensor(builder);
  43560. + int emptyTensor = Tensor.endTensor(builder);
  43561. +
  43562. + // Creates an input Tensor with invalid quantization parameters.
  43563. + int invalidTensor = createTensor(builder, validShape, dataType, inValidQuantization);
  43564. +
  43565. + // Creates the SubGraph.
  43566. + int[] tensors = new int[6];
  43567. + tensors[0] = validTensor;
  43568. + tensors[1] = emptyTensor;
  43569. + tensors[2] = invalidTensor;
  43570. + tensors[3] = validTensor;
  43571. + tensors[4] = emptyTensor;
  43572. + tensors[5] = invalidTensor;
  43573. + int subgraphTensors = SubGraph.createTensorsVector(builder, tensors);
  43574. +
  43575. + int subgraphInputs = SubGraph.createInputsVector(builder, new int[] {0, 1, 2});
  43576. + int subgraphOutputs = SubGraph.createOutputsVector(builder, new int[] {3, 4, 5});
  43577. +
  43578. + SubGraph.startSubGraph(builder);
  43579. + SubGraph.addTensors(builder, subgraphTensors);
  43580. + SubGraph.addInputs(builder, subgraphInputs);
  43581. + SubGraph.addOutputs(builder, subgraphOutputs);
  43582. + int subgraph = SubGraph.endSubGraph(builder);
  43583. +
  43584. + // Creates the Model.
  43585. + int[] subgraphs = new int[1];
  43586. + subgraphs[0] = subgraph;
  43587. + int modelSubgraphs = Model.createSubgraphsVector(builder, subgraphs);
  43588. +
  43589. + // Inserts metadataBuffer into the model if it's not null.
  43590. + int modelBuffers = EMPTY_FLATBUFFER_VECTOR;
  43591. + int metadataArray = EMPTY_FLATBUFFER_VECTOR;
  43592. + if (metadataBuffer != null) {
  43593. + int data = Buffer.createDataVector(builder, metadataBuffer);
  43594. + Buffer.startBuffer(builder);
  43595. + Buffer.addData(builder, data);
  43596. + int buffer = Buffer.endBuffer(builder);
  43597. + modelBuffers = Model.createBuffersVector(builder, new int[] {buffer});
  43598. +
  43599. + int metadataName = builder.createString(ModelInfo.METADATA_FIELD_NAME);
  43600. + Metadata.startMetadata(builder);
  43601. + Metadata.addName(builder, metadataName);
  43602. + Metadata.addBuffer(builder, 0);
  43603. + int metadata = Metadata.endMetadata(builder);
  43604. + metadataArray = Model.createMetadataVector(builder, new int[] {metadata});
  43605. + }
  43606. +
  43607. + Model.startModel(builder);
  43608. + Model.addSubgraphs(builder, modelSubgraphs);
  43609. + if (modelBuffers != EMPTY_FLATBUFFER_VECTOR && metadataArray != EMPTY_FLATBUFFER_VECTOR) {
  43610. + Model.addBuffers(builder, modelBuffers);
  43611. + Model.addMetadata(builder, metadataArray);
  43612. + }
  43613. + int model = Model.endModel(builder);
  43614. + builder.finish(model, TFLITE_MODEL_IDENTIFIER);
  43615. +
  43616. + return builder.dataBuffer();
  43617. + }
  43618. +
  43619. + /** Creates an example model flatbuffer with the default metadata and data type. */
  43620. + private static ByteBuffer createModelByteBuffer() {
  43621. + ByteBuffer metadata =
  43622. + createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, /*minVersionStr=*/null);
  43623. + return createModelByteBuffer(metadata, DATA_TYPE);
  43624. + }
  43625. +
  43626. + private static ByteBuffer loadMobileNetBuffer() throws Exception {
  43627. + Context context = ApplicationProvider.getApplicationContext();
  43628. + // Loads a MobileNet model flatbuffer with metadata. The MobileNet model is a zip file that
  43629. + // contains a label file as the associated file.
  43630. + AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH);
  43631. + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
  43632. + FileChannel fileChannel = inputStream.getChannel();
  43633. + long startOffset = fileDescriptor.getStartOffset();
  43634. + long declaredLength = fileDescriptor.getDeclaredLength();
  43635. + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  43636. + }
  43637. +
  43638. + private static ByteBuffer createRandomByteBuffer() {
  43639. + byte[] buffer = new byte[20];
  43640. + new Random().nextBytes(buffer);
  43641. + return ByteBuffer.wrap(buffer);
  43642. }
  43643. - int model = Model.endModel(builder);
  43644. - builder.finish(model, TFLITE_MODEL_IDENTIFIER);
  43645. -
  43646. - return builder.dataBuffer();
  43647. - }
  43648. -
  43649. - /** Creates an example model flatbuffer with the default metadata and data type. */
  43650. - private static ByteBuffer createModelByteBuffer() {
  43651. - ByteBuffer metadata =
  43652. - createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, /*minVersionStr=*/ null);
  43653. - return createModelByteBuffer(metadata, DATA_TYPE);
  43654. - }
  43655. -
  43656. - private static ByteBuffer loadMobileNetBuffer() throws Exception {
  43657. - Context context = ApplicationProvider.getApplicationContext();
  43658. - // Loads a MobileNet model flatbuffer with metadata. The MobileNet model is a zip file that
  43659. - // contains a label file as the associated file.
  43660. - AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH);
  43661. - FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
  43662. - FileChannel fileChannel = inputStream.getChannel();
  43663. - long startOffset = fileDescriptor.getStartOffset();
  43664. - long declaredLength = fileDescriptor.getDeclaredLength();
  43665. - return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  43666. - }
  43667. -
  43668. - private static ByteBuffer createRandomByteBuffer() {
  43669. - byte[] buffer = new byte[20];
  43670. - new Random().nextBytes(buffer);
  43671. - return ByteBuffer.wrap(buffer);
  43672. - }
  43673. }
  43674. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataParserTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataParserTest.java
  43675. index a47566fec06e9..eede6750ea479 100644
  43676. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataParserTest.java
  43677. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataParserTest.java
  43678. @@ -17,20 +17,20 @@ package org.tensorflow.lite.support.metadata;
  43679. import static com.google.common.truth.Truth.assertThat;
  43680. -import java.util.regex.Pattern;
  43681. import org.junit.Test;
  43682. import org.junit.runner.RunWith;
  43683. import org.junit.runners.JUnit4;
  43684. +import java.util.regex.Pattern;
  43685. +
  43686. /** Tests of {@link MetadataParser}. */
  43687. @RunWith(JUnit4.class)
  43688. public final class MetadataParserTest {
  43689. -
  43690. - @Test
  43691. - public void version_wellFormedAsSemanticVersion() throws Exception {
  43692. - // Validates that the version is well-formed (x.y.z).
  43693. - String pattern = "[0-9]+\\.[0-9]+\\.[0-9]+";
  43694. - Pattern r = Pattern.compile(pattern);
  43695. - assertThat(MetadataParser.VERSION).matches(r);
  43696. - }
  43697. + @Test
  43698. + public void version_wellFormedAsSemanticVersion() throws Exception {
  43699. + // Validates that the version is well-formed (x.y.z).
  43700. + String pattern = "[0-9]+\\.[0-9]+\\.[0-9]+";
  43701. + Pattern r = Pattern.compile(pattern);
  43702. + assertThat(MetadataParser.VERSION).matches(r);
  43703. + }
  43704. }
  43705. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java
  43706. index 61231e902e03e..80d2ddc6fd34e 100644
  43707. --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java
  43708. +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java
  43709. @@ -16,11 +16,20 @@ limitations under the License.
  43710. package org.tensorflow.lite.support.metadata;
  43711. import static com.google.common.truth.Truth.assertThat;
  43712. +
  43713. import static org.junit.Assert.assertThrows;
  43714. import android.content.Context;
  43715. import android.content.res.AssetFileDescriptor;
  43716. +
  43717. import androidx.test.core.app.ApplicationProvider;
  43718. +
  43719. +import org.apache.commons.io.IOUtils;
  43720. +import org.junit.Ignore;
  43721. +import org.junit.Test;
  43722. +import org.junit.runner.RunWith;
  43723. +import org.robolectric.RobolectricTestRunner;
  43724. +
  43725. import java.io.FileInputStream;
  43726. import java.io.InputStream;
  43727. import java.nio.ByteBuffer;
  43728. @@ -28,113 +37,102 @@ import java.nio.channels.FileChannel;
  43729. import java.util.HashSet;
  43730. import java.util.Set;
  43731. import java.util.zip.ZipException;
  43732. -import org.apache.commons.io.IOUtils;
  43733. -import org.junit.Test;
  43734. -import org.junit.runner.RunWith;
  43735. -import org.robolectric.RobolectricTestRunner;
  43736. -
  43737. -import org.junit.Ignore;
  43738. /** Tests of {@link ZipFile}. */
  43739. @RunWith(RobolectricTestRunner.class)
  43740. public final class ZipFileTest {
  43741. -
  43742. - // The TFLite model file is a zip file.
  43743. - private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite";
  43744. - // labels.txt is packed in mobilenet_v1_1.0_224_quant.tflite as an associated file.
  43745. - private static final String VALID_LABEL_FILE_NAME = "labels.txt";
  43746. - // invalid.txt is not packed in mobilenet_v1_1.0_224_quant.tflite.
  43747. - private static final String INVALID_LABEL_FILE_NAME = "invalid.txt";
  43748. - private final Context context = ApplicationProvider.getApplicationContext();
  43749. -
  43750. - @Test
  43751. - public void zipFile_nullChannel_throwsException() throws Exception {
  43752. - NullPointerException exception =
  43753. - assertThrows(NullPointerException.class, () -> ZipFile.createFrom(null));
  43754. - assertThat(exception).hasMessageThat().isEqualTo("The object reference is null.");
  43755. - }
  43756. -
  43757. - @Test
  43758. - public void zipFile_invalidFileWithExtremeSmallSize_throwsException() throws Exception {
  43759. - // The size limit for a zip file is the End head size, ZipConstant.ENDHDR, which is 22.
  43760. - ByteBuffer modelBuffer = ByteBuffer.allocate(21);
  43761. - ByteBufferChannel modelChannel = new ByteBufferChannel(modelBuffer);
  43762. -
  43763. - ZipException exception =
  43764. - assertThrows(ZipException.class, () -> ZipFile.createFrom(modelChannel));
  43765. - assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive.");
  43766. - }
  43767. -
  43768. - @Test
  43769. - public void zipFile_invalidFileWithNoSignature_throwsException() throws Exception {
  43770. - // An invalid zip file that meets the size requirement but does not contain the zip signature.
  43771. - ByteBuffer modelBuffer = ByteBuffer.allocate(22);
  43772. - ByteBufferChannel modelChannel = new ByteBufferChannel(modelBuffer);
  43773. -
  43774. - ZipException exception =
  43775. - assertThrows(ZipException.class, () -> ZipFile.createFrom(modelChannel));
  43776. - assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive.");
  43777. - }
  43778. -
  43779. - @Ignore
  43780. - @Test
  43781. - public void getFileNames_correctFileName() throws Exception {
  43782. - ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
  43783. - ZipFile zipFile = ZipFile.createFrom(modelChannel);
  43784. - Set<String> expectedSet = new HashSet<>();
  43785. - expectedSet.add(VALID_LABEL_FILE_NAME);
  43786. - assertThat(zipFile.getFileNames()).isEqualTo(expectedSet);
  43787. - }
  43788. -
  43789. - @Ignore
  43790. - @Test
  43791. - public void getRawInputStream_existentFile() throws Exception {
  43792. - ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
  43793. - ZipFile zipFile = ZipFile.createFrom(modelChannel);
  43794. - InputStream fileStream = zipFile.getRawInputStream(VALID_LABEL_FILE_NAME);
  43795. -
  43796. - // Reads the golden file from context.
  43797. - InputStream goldenFileStream = context.getAssets().open(VALID_LABEL_FILE_NAME);
  43798. - assertThat(IOUtils.contentEquals(goldenFileStream, fileStream)).isTrue();
  43799. - }
  43800. -
  43801. - @Ignore
  43802. - @Test
  43803. - public void getRawInputStream_nonExistentFile() throws Exception {
  43804. - ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
  43805. - ZipFile zipFile = ZipFile.createFrom(modelChannel);
  43806. -
  43807. - IllegalArgumentException exception =
  43808. - assertThrows(
  43809. - IllegalArgumentException.class,
  43810. - () -> zipFile.getRawInputStream(INVALID_LABEL_FILE_NAME));
  43811. - assertThat(exception)
  43812. - .hasMessageThat()
  43813. - .isEqualTo(
  43814. - String.format(
  43815. + // The TFLite model file is a zip file.
  43816. + private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite";
  43817. + // labels.txt is packed in mobilenet_v1_1.0_224_quant.tflite as an associated file.
  43818. + private static final String VALID_LABEL_FILE_NAME = "labels.txt";
  43819. + // invalid.txt is not packed in mobilenet_v1_1.0_224_quant.tflite.
  43820. + private static final String INVALID_LABEL_FILE_NAME = "invalid.txt";
  43821. + private final Context context = ApplicationProvider.getApplicationContext();
  43822. +
  43823. + @Test
  43824. + public void zipFile_nullChannel_throwsException() throws Exception {
  43825. + NullPointerException exception =
  43826. + assertThrows(NullPointerException.class, () -> ZipFile.createFrom(null));
  43827. + assertThat(exception).hasMessageThat().isEqualTo("The object reference is null.");
  43828. + }
  43829. +
  43830. + @Test
  43831. + public void zipFile_invalidFileWithExtremeSmallSize_throwsException() throws Exception {
  43832. + // The size limit for a zip file is the End head size, ZipConstant.ENDHDR, which is 22.
  43833. + ByteBuffer modelBuffer = ByteBuffer.allocate(21);
  43834. + ByteBufferChannel modelChannel = new ByteBufferChannel(modelBuffer);
  43835. +
  43836. + ZipException exception =
  43837. + assertThrows(ZipException.class, () -> ZipFile.createFrom(modelChannel));
  43838. + assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive.");
  43839. + }
  43840. +
  43841. + @Test
  43842. + public void zipFile_invalidFileWithNoSignature_throwsException() throws Exception {
  43843. + // An invalid zip file that meets the size requirement but does not contain the zip
  43844. + // signature.
  43845. + ByteBuffer modelBuffer = ByteBuffer.allocate(22);
  43846. + ByteBufferChannel modelChannel = new ByteBufferChannel(modelBuffer);
  43847. +
  43848. + ZipException exception =
  43849. + assertThrows(ZipException.class, () -> ZipFile.createFrom(modelChannel));
  43850. + assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive.");
  43851. + }
  43852. +
  43853. + @Ignore
  43854. + @Test
  43855. + public void getFileNames_correctFileName() throws Exception {
  43856. + ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
  43857. + ZipFile zipFile = ZipFile.createFrom(modelChannel);
  43858. + Set<String> expectedSet = new HashSet<>();
  43859. + expectedSet.add(VALID_LABEL_FILE_NAME);
  43860. + assertThat(zipFile.getFileNames()).isEqualTo(expectedSet);
  43861. + }
  43862. +
  43863. + @Ignore
  43864. + @Test
  43865. + public void getRawInputStream_existentFile() throws Exception {
  43866. + ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
  43867. + ZipFile zipFile = ZipFile.createFrom(modelChannel);
  43868. + InputStream fileStream = zipFile.getRawInputStream(VALID_LABEL_FILE_NAME);
  43869. +
  43870. + // Reads the golden file from context.
  43871. + InputStream goldenFileStream = context.getAssets().open(VALID_LABEL_FILE_NAME);
  43872. + assertThat(IOUtils.contentEquals(goldenFileStream, fileStream)).isTrue();
  43873. + }
  43874. +
  43875. + @Ignore
  43876. + @Test
  43877. + public void getRawInputStream_nonExistentFile() throws Exception {
  43878. + ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
  43879. + ZipFile zipFile = ZipFile.createFrom(modelChannel);
  43880. +
  43881. + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
  43882. + () -> zipFile.getRawInputStream(INVALID_LABEL_FILE_NAME));
  43883. + assertThat(exception).hasMessageThat().isEqualTo(String.format(
  43884. "The file, %s, does not exist in the zip file.", INVALID_LABEL_FILE_NAME));
  43885. - }
  43886. -
  43887. - @Ignore
  43888. - @Test
  43889. - public void close_validStatus() throws Exception {
  43890. - ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
  43891. - ZipFile zipFile = ZipFile.createFrom(modelChannel);
  43892. - // Should do nothing (including not throwing an exception).
  43893. - zipFile.close();
  43894. - }
  43895. -
  43896. - private static ByteBufferChannel loadModel(String modelPath) throws Exception {
  43897. - // Creates a ZipFile with a TFLite model flatbuffer with metadata. The MobileNet
  43898. - // model is a zip file that contains a label file as the associated file.
  43899. - Context context = ApplicationProvider.getApplicationContext();
  43900. - AssetFileDescriptor fileDescriptor = context.getAssets().openFd(modelPath);
  43901. - FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
  43902. - FileChannel fileChannel = inputStream.getChannel();
  43903. - long startOffset = fileDescriptor.getStartOffset();
  43904. - long declaredLength = fileDescriptor.getDeclaredLength();
  43905. - ByteBuffer modelBuffer =
  43906. - fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  43907. - return new ByteBufferChannel(modelBuffer);
  43908. - }
  43909. + }
  43910. +
  43911. + @Ignore
  43912. + @Test
  43913. + public void close_validStatus() throws Exception {
  43914. + ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
  43915. + ZipFile zipFile = ZipFile.createFrom(modelChannel);
  43916. + // Should do nothing (including not throwing an exception).
  43917. + zipFile.close();
  43918. + }
  43919. +
  43920. + private static ByteBufferChannel loadModel(String modelPath) throws Exception {
  43921. + // Creates a ZipFile with a TFLite model flatbuffer with metadata. The MobileNet
  43922. + // model is a zip file that contains a label file as the associated file.
  43923. + Context context = ApplicationProvider.getApplicationContext();
  43924. + AssetFileDescriptor fileDescriptor = context.getAssets().openFd(modelPath);
  43925. + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
  43926. + FileChannel fileChannel = inputStream.getChannel();
  43927. + long startOffset = fileDescriptor.getStartOffset();
  43928. + long declaredLength = fileDescriptor.getDeclaredLength();
  43929. + ByteBuffer modelBuffer =
  43930. + fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  43931. + return new ByteBufferChannel(modelBuffer);
  43932. + }
  43933. }
  43934. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h
  43935. index 110186bb63a1b..18797d8135eb8 100644
  43936. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h
  43937. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h
  43938. @@ -19,7 +19,8 @@
  43939. NS_ASSUME_NONNULL_BEGIN
  43940. /** Types of image sources. */
  43941. -typedef NSInteger GMLImageSourceType NS_TYPED_ENUM NS_SWIFT_NAME(MLImageSourceType);
  43942. +typedef NSInteger GMLImageSourceType
  43943. + NS_TYPED_ENUM NS_SWIFT_NAME(MLImageSourceType);
  43944. /** Image source is a `UIImage`. */
  43945. static const GMLImageSourceType GMLImageSourceTypeImage = 0;
  43946. /** Image source is a `CVPixelBuffer`. */
  43947. @@ -38,8 +39,9 @@ NS_SWIFT_NAME(MLImage)
  43948. @property(nonatomic, readonly) CGFloat height;
  43949. /**
  43950. - * The display orientation of the image. If `imageSourceType` is `.image`, the default value is
  43951. - * `image.imageOrientation`; otherwise the default value is `.up`.
  43952. + * The display orientation of the image. If `imageSourceType` is `.image`, the
  43953. + * default value is `image.imageOrientation`; otherwise the default value is
  43954. + * `.up`.
  43955. */
  43956. @property(nonatomic) UIImageOrientation orientation;
  43957. @@ -47,30 +49,34 @@ NS_SWIFT_NAME(MLImage)
  43958. @property(nonatomic, readonly) GMLImageSourceType imageSourceType;
  43959. /** The source image. `nil` if `imageSourceType` is not `.image`. */
  43960. -@property(nonatomic, readonly, nullable) UIImage *image;
  43961. +@property(nonatomic, readonly, nullable) UIImage* image;
  43962. -/** The source pixel buffer. `nil` if `imageSourceType` is not `.pixelBuffer`. */
  43963. +/** The source pixel buffer. `nil` if `imageSourceType` is not `.pixelBuffer`.
  43964. + */
  43965. @property(nonatomic, readonly, nullable) CVPixelBufferRef pixelBuffer;
  43966. -/** The source sample buffer. `nil` if `imageSourceType` is not `.sampleBuffer`. */
  43967. +/** The source sample buffer. `nil` if `imageSourceType` is not `.sampleBuffer`.
  43968. + */
  43969. @property(nonatomic, readonly, nullable) CMSampleBufferRef sampleBuffer;
  43970. /**
  43971. * Initializes an `MLImage` object with the given image.
  43972. *
  43973. - * @param image The image to use as the source. Its `CGImage` property must not be `NULL`.
  43974. - * @return A new `MLImage` instance with the given image as the source. `nil` if the given `image`
  43975. - * is `nil` or invalid.
  43976. + * @param image The image to use as the source. Its `CGImage` property must not
  43977. + * be `NULL`.
  43978. + * @return A new `MLImage` instance with the given image as the source. `nil` if
  43979. + * the given `image` is `nil` or invalid.
  43980. */
  43981. -- (nullable instancetype)initWithImage:(UIImage *)image NS_DESIGNATED_INITIALIZER;
  43982. +- (nullable instancetype)initWithImage:(UIImage*)image
  43983. + NS_DESIGNATED_INITIALIZER;
  43984. /**
  43985. * Initializes an `MLImage` object with the given pixel buffer.
  43986. *
  43987. - * @param pixelBuffer The pixel buffer to use as the source. It will be retained by the new
  43988. - * `MLImage` instance for the duration of its lifecycle.
  43989. - * @return A new `MLImage` instance with the given pixel buffer as the source. `nil` if the given
  43990. - * pixel buffer is `nil` or invalid.
  43991. + * @param pixelBuffer The pixel buffer to use as the source. It will be retained
  43992. + * by the new `MLImage` instance for the duration of its lifecycle.
  43993. + * @return A new `MLImage` instance with the given pixel buffer as the source.
  43994. + * `nil` if the given pixel buffer is `nil` or invalid.
  43995. */
  43996. - (nullable instancetype)initWithPixelBuffer:(CVPixelBufferRef)pixelBuffer
  43997. NS_DESIGNATED_INITIALIZER;
  43998. @@ -78,12 +84,13 @@ NS_SWIFT_NAME(MLImage)
  43999. /**
  44000. * Initializes an `MLImage` object with the given sample buffer.
  44001. *
  44002. - * @param sampleBuffer The sample buffer to use as the source. It will be retained by the new
  44003. - * `MLImage` instance for the duration of its lifecycle. The sample buffer must be based on a
  44004. - * pixel buffer (not compressed data). In practice, it should be the video output of the camera
  44005. - * on an iOS device, not other arbitrary types of `CMSampleBuffer`s.
  44006. - * @return A new `MLImage` instance with the given sample buffer as the source. `nil` if the given
  44007. - * sample buffer is `nil` or invalid.
  44008. + * @param sampleBuffer The sample buffer to use as the source. It will be
  44009. + * retained by the new `MLImage` instance for the duration of its lifecycle. The
  44010. + * sample buffer must be based on a pixel buffer (not compressed data). In
  44011. + * practice, it should be the video output of the camera on an iOS device, not
  44012. + * other arbitrary types of `CMSampleBuffer`s.
  44013. + * @return A new `MLImage` instance with the given sample buffer as the source.
  44014. + * `nil` if the given sample buffer is `nil` or invalid.
  44015. */
  44016. - (nullable instancetype)initWithSampleBuffer:(CMSampleBufferRef)sampleBuffer
  44017. NS_DESIGNATED_INITIALIZER;
  44018. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java
  44019. index a32fc24749e0c..59116a72a0533 100644
  44020. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java
  44021. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java
  44022. @@ -24,28 +24,27 @@ import android.graphics.Bitmap;
  44023. * {@link IllegalArgumentException} will be thrown.
  44024. */
  44025. public final class BitmapExtractor {
  44026. -
  44027. - /**
  44028. - * Extracts a {@link android.graphics.Bitmap} from an {@link MlImage}.
  44029. - *
  44030. - * <p>Notice: Properties of the {@code image} like rotation will not take effects.
  44031. - *
  44032. - * @param image the image to extract {@link android.graphics.Bitmap} from.
  44033. - * @return the {@link android.graphics.Bitmap} stored in {@link MlImage}
  44034. - * @throws IllegalArgumentException when the extraction requires unsupported format or data type
  44035. - * conversions.
  44036. - */
  44037. - public static Bitmap extract(MlImage image) {
  44038. - ImageContainer imageContainer = image.getContainer(MlImage.STORAGE_TYPE_BITMAP);
  44039. - if (imageContainer != null) {
  44040. - return ((BitmapImageContainer) imageContainer).getBitmap();
  44041. - } else {
  44042. - // TODO(b/180504869): Support ByteBuffer -> Bitmap conversion.
  44043. - throw new IllegalArgumentException(
  44044. - "Extracting Bitmap from an MlImage created by objects other than Bitmap is not"
  44045. - + " supported");
  44046. + /**
  44047. + * Extracts a {@link android.graphics.Bitmap} from an {@link MlImage}.
  44048. + *
  44049. + * <p>Notice: Properties of the {@code image} like rotation will not take effects.
  44050. + *
  44051. + * @param image the image to extract {@link android.graphics.Bitmap} from.
  44052. + * @return the {@link android.graphics.Bitmap} stored in {@link MlImage}
  44053. + * @throws IllegalArgumentException when the extraction requires unsupported format or data type
  44054. + * conversions.
  44055. + */
  44056. + public static Bitmap extract(MlImage image) {
  44057. + ImageContainer imageContainer = image.getContainer(MlImage.STORAGE_TYPE_BITMAP);
  44058. + if (imageContainer != null) {
  44059. + return ((BitmapImageContainer) imageContainer).getBitmap();
  44060. + } else {
  44061. + // TODO(b/180504869): Support ByteBuffer -> Bitmap conversion.
  44062. + throw new IllegalArgumentException(
  44063. + "Extracting Bitmap from an MlImage created by objects other than Bitmap is not"
  44064. + + " supported");
  44065. + }
  44066. }
  44067. - }
  44068. - private BitmapExtractor() {}
  44069. + private BitmapExtractor() {}
  44070. }
  44071. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java
  44072. index 77e63f0351449..b1b02f8e369ec 100644
  44073. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java
  44074. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java
  44075. @@ -16,44 +16,44 @@ limitations under the License.
  44076. package com.google.android.odml.image;
  44077. import android.graphics.Bitmap;
  44078. +
  44079. import com.google.android.odml.image.MlImage.ImageFormat;
  44080. class BitmapImageContainer implements ImageContainer {
  44081. + private final Bitmap bitmap;
  44082. + private final ImageProperties properties;
  44083. +
  44084. + public BitmapImageContainer(Bitmap bitmap) {
  44085. + this.bitmap = bitmap;
  44086. + this.properties = ImageProperties.builder()
  44087. + .setImageFormat(convertFormatCode(bitmap.getConfig()))
  44088. + .setStorageType(MlImage.STORAGE_TYPE_BITMAP)
  44089. + .build();
  44090. + }
  44091. +
  44092. + public Bitmap getBitmap() {
  44093. + return bitmap;
  44094. + }
  44095. +
  44096. + @Override
  44097. + public ImageProperties getImageProperties() {
  44098. + return properties;
  44099. + }
  44100. +
  44101. + @Override
  44102. + public void close() {
  44103. + bitmap.recycle();
  44104. + }
  44105. - private final Bitmap bitmap;
  44106. - private final ImageProperties properties;
  44107. -
  44108. - public BitmapImageContainer(Bitmap bitmap) {
  44109. - this.bitmap = bitmap;
  44110. - this.properties = ImageProperties.builder()
  44111. - .setImageFormat(convertFormatCode(bitmap.getConfig()))
  44112. - .setStorageType(MlImage.STORAGE_TYPE_BITMAP)
  44113. - .build();
  44114. - }
  44115. -
  44116. - public Bitmap getBitmap() {
  44117. - return bitmap;
  44118. - }
  44119. -
  44120. - @Override
  44121. - public ImageProperties getImageProperties() {
  44122. - return properties;
  44123. - }
  44124. -
  44125. - @Override
  44126. - public void close() {
  44127. - bitmap.recycle();
  44128. - }
  44129. -
  44130. - @ImageFormat
  44131. - static int convertFormatCode(Bitmap.Config config) {
  44132. - switch (config) {
  44133. - case ALPHA_8:
  44134. - return MlImage.IMAGE_FORMAT_ALPHA;
  44135. - case ARGB_8888:
  44136. - return MlImage.IMAGE_FORMAT_RGBA;
  44137. - default:
  44138. - return MlImage.IMAGE_FORMAT_UNKNOWN;
  44139. + @ImageFormat
  44140. + static int convertFormatCode(Bitmap.Config config) {
  44141. + switch (config) {
  44142. + case ALPHA_8:
  44143. + return MlImage.IMAGE_FORMAT_ALPHA;
  44144. + case ARGB_8888:
  44145. + return MlImage.IMAGE_FORMAT_RGBA;
  44146. + default:
  44147. + return MlImage.IMAGE_FORMAT_UNKNOWN;
  44148. + }
  44149. }
  44150. - }
  44151. }
  44152. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java
  44153. index fe9c35a8a6ede..6c4552bfdac3a 100644
  44154. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java
  44155. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java
  44156. @@ -20,6 +20,7 @@ import android.graphics.Bitmap;
  44157. import android.graphics.Rect;
  44158. import android.net.Uri;
  44159. import android.provider.MediaStore;
  44160. +
  44161. import java.io.IOException;
  44162. /**
  44163. @@ -32,82 +33,76 @@ import java.io.IOException;
  44164. * <p>Use {@link BitmapExtractor} to get {@link android.graphics.Bitmap} you passed in.
  44165. */
  44166. public class BitmapMlImageBuilder {
  44167. + // Mandatory fields.
  44168. + private final Bitmap bitmap;
  44169. - // Mandatory fields.
  44170. - private final Bitmap bitmap;
  44171. -
  44172. - // Optional fields.
  44173. - private int rotation;
  44174. - private Rect roi;
  44175. - private long timestamp;
  44176. + // Optional fields.
  44177. + private int rotation;
  44178. + private Rect roi;
  44179. + private long timestamp;
  44180. - /**
  44181. - * Creates the builder with a mandatory {@link android.graphics.Bitmap}.
  44182. - *
  44183. - * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the values
  44184. - * will be set with default:
  44185. - *
  44186. - * <ul>
  44187. - * <li>rotation: 0
  44188. - * </ul>
  44189. - *
  44190. - * @param bitmap image data object.
  44191. - */
  44192. - public BitmapMlImageBuilder(Bitmap bitmap) {
  44193. - this.bitmap = bitmap;
  44194. - rotation = 0;
  44195. - roi = new Rect(0, 0, bitmap.getWidth(), bitmap.getHeight());
  44196. - timestamp = 0;
  44197. - }
  44198. + /**
  44199. + * Creates the builder with a mandatory {@link android.graphics.Bitmap}.
  44200. + *
  44201. + * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the
  44202. + * values will be set with default:
  44203. + *
  44204. + * <ul>
  44205. + * <li>rotation: 0
  44206. + * </ul>
  44207. + *
  44208. + * @param bitmap image data object.
  44209. + */
  44210. + public BitmapMlImageBuilder(Bitmap bitmap) {
  44211. + this.bitmap = bitmap;
  44212. + rotation = 0;
  44213. + roi = new Rect(0, 0, bitmap.getWidth(), bitmap.getHeight());
  44214. + timestamp = 0;
  44215. + }
  44216. - /**
  44217. - * Creates the builder to build {@link MlImage} from a file.
  44218. - *
  44219. - * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the values
  44220. - * will be set with default:
  44221. - *
  44222. - * <ul>
  44223. - * <li>rotation: 0
  44224. - * </ul>
  44225. - *
  44226. - * @param context the application context.
  44227. - * @param uri the path to the resource file.
  44228. - */
  44229. - public BitmapMlImageBuilder(Context context, Uri uri) throws IOException {
  44230. - this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
  44231. - }
  44232. + /**
  44233. + * Creates the builder to build {@link MlImage} from a file.
  44234. + *
  44235. + * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the
  44236. + * values will be set with default:
  44237. + *
  44238. + * <ul>
  44239. + * <li>rotation: 0
  44240. + * </ul>
  44241. + *
  44242. + * @param context the application context.
  44243. + * @param uri the path to the resource file.
  44244. + */
  44245. + public BitmapMlImageBuilder(Context context, Uri uri) throws IOException {
  44246. + this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
  44247. + }
  44248. - /**
  44249. - * Sets value for {@link MlImage#getRotation()}.
  44250. - *
  44251. - * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
  44252. - */
  44253. - public BitmapMlImageBuilder setRotation(int rotation) {
  44254. - MlImage.validateRotation(rotation);
  44255. - this.rotation = rotation;
  44256. - return this;
  44257. - }
  44258. + /**
  44259. + * Sets value for {@link MlImage#getRotation()}.
  44260. + *
  44261. + * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
  44262. + */
  44263. + public BitmapMlImageBuilder setRotation(int rotation) {
  44264. + MlImage.validateRotation(rotation);
  44265. + this.rotation = rotation;
  44266. + return this;
  44267. + }
  44268. - /** Sets value for {@link MlImage#getRoi()}. */
  44269. - BitmapMlImageBuilder setRoi(Rect roi) {
  44270. - this.roi = roi;
  44271. - return this;
  44272. - }
  44273. + /** Sets value for {@link MlImage#getRoi()}. */
  44274. + BitmapMlImageBuilder setRoi(Rect roi) {
  44275. + this.roi = roi;
  44276. + return this;
  44277. + }
  44278. - /** Sets value for {@link MlImage#getTimestamp()}. */
  44279. - BitmapMlImageBuilder setTimestamp(long timestamp) {
  44280. - this.timestamp = timestamp;
  44281. - return this;
  44282. - }
  44283. + /** Sets value for {@link MlImage#getTimestamp()}. */
  44284. + BitmapMlImageBuilder setTimestamp(long timestamp) {
  44285. + this.timestamp = timestamp;
  44286. + return this;
  44287. + }
  44288. - /** Builds an {@link MlImage} instance. */
  44289. - public MlImage build() {
  44290. - return new MlImage(
  44291. - new BitmapImageContainer(bitmap),
  44292. - rotation,
  44293. - roi,
  44294. - timestamp,
  44295. - bitmap.getWidth(),
  44296. - bitmap.getHeight());
  44297. - }
  44298. + /** Builds an {@link MlImage} instance. */
  44299. + public MlImage build() {
  44300. + return new MlImage(new BitmapImageContainer(bitmap), rotation, roi, timestamp,
  44301. + bitmap.getWidth(), bitmap.getHeight());
  44302. + }
  44303. }
  44304. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java
  44305. index 7b86be6d1b533..d5861c8ca94ac 100644
  44306. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java
  44307. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java
  44308. @@ -19,8 +19,10 @@ import android.graphics.Bitmap;
  44309. import android.graphics.Bitmap.Config;
  44310. import android.os.Build.VERSION;
  44311. import android.os.Build.VERSION_CODES;
  44312. +
  44313. import com.google.android.odml.image.MlImage.ImageFormat;
  44314. import com.google.auto.value.AutoValue;
  44315. +
  44316. import java.nio.ByteBuffer;
  44317. import java.nio.ByteOrder;
  44318. import java.util.Locale;
  44319. @@ -32,229 +34,234 @@ import java.util.Locale;
  44320. * otherwise {@link IllegalArgumentException} will be thrown.
  44321. */
  44322. public class ByteBufferExtractor {
  44323. -
  44324. - /**
  44325. - * Extracts a {@link ByteBuffer} from an {@link MlImage}.
  44326. - *
  44327. - * <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
  44328. - * ImageProperties} whose storage type is {@code MlImage.STORAGE_TYPE_BYTEBUFFER}.
  44329. - *
  44330. - * @see MlImage#getContainedImageProperties()
  44331. - * @return A read-only {@link ByteBuffer}.
  44332. - * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
  44333. - */
  44334. - public static ByteBuffer extract(MlImage image) {
  44335. - ImageContainer container = image.getContainer();
  44336. - switch (container.getImageProperties().getStorageType()) {
  44337. - case MlImage.STORAGE_TYPE_BYTEBUFFER:
  44338. - ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
  44339. - return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
  44340. - default:
  44341. - throw new IllegalArgumentException(
  44342. - "Extract ByteBuffer from an MlImage created by objects other than Bytebuffer is not"
  44343. - + " supported");
  44344. - }
  44345. - }
  44346. -
  44347. - /**
  44348. - * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link MlImage}.
  44349. - *
  44350. - * <p>Notice: Properties of the {@code image} like rotation will not take effects.
  44351. - *
  44352. - * <p>Format conversion spec:
  44353. - *
  44354. - * <ul>
  44355. - * <li>When extracting RGB images to RGBA format, A channel will always set to 255.
  44356. - * <li>When extracting RGBA images to RGB format, A channel will be dropped.
  44357. - * </ul>
  44358. - *
  44359. - * @param image the image to extract buffer from.
  44360. - * @param targetFormat the image format of the result bytebuffer.
  44361. - * @return the readonly {@link ByteBuffer} stored in {@link MlImage}
  44362. - * @throws IllegalArgumentException when the extraction requires unsupported format or data type
  44363. - * conversions.
  44364. - */
  44365. - static ByteBuffer extract(MlImage image, @ImageFormat int targetFormat) {
  44366. - ImageContainer container;
  44367. - ImageProperties byteBufferProperties =
  44368. - ImageProperties.builder()
  44369. - .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
  44370. - .setImageFormat(targetFormat)
  44371. - .build();
  44372. - if ((container = image.getContainer(byteBufferProperties)) != null) {
  44373. - ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
  44374. - return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
  44375. - } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
  44376. - ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
  44377. - @ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
  44378. - return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
  44379. - .asReadOnlyBuffer();
  44380. - } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) {
  44381. - BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
  44382. - ByteBuffer byteBuffer =
  44383. - extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)
  44384. - .asReadOnlyBuffer();
  44385. - image.addContainer(new ByteBufferImageContainer(byteBuffer, targetFormat));
  44386. - return byteBuffer;
  44387. - } else {
  44388. - throw new IllegalArgumentException(
  44389. - "Extracting ByteBuffer from an MlImage created by objects other than Bitmap or"
  44390. - + " Bytebuffer is not supported");
  44391. - }
  44392. - }
  44393. -
  44394. - /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */
  44395. - @AutoValue
  44396. - abstract static class Result {
  44397. /**
  44398. - * Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(MlImage)}.
  44399. + * Extracts a {@link ByteBuffer} from an {@link MlImage}.
  44400. + *
  44401. + * <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
  44402. + * ImageProperties} whose storage type is {@code MlImage.STORAGE_TYPE_BYTEBUFFER}.
  44403. + *
  44404. + * @see MlImage#getContainedImageProperties()
  44405. + * @return A read-only {@link ByteBuffer}.
  44406. + * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
  44407. */
  44408. - public abstract ByteBuffer buffer();
  44409. + public static ByteBuffer extract(MlImage image) {
  44410. + ImageContainer container = image.getContainer();
  44411. + switch (container.getImageProperties().getStorageType()) {
  44412. + case MlImage.STORAGE_TYPE_BYTEBUFFER:
  44413. + ByteBufferImageContainer byteBufferImageContainer =
  44414. + (ByteBufferImageContainer) container;
  44415. + return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
  44416. + default:
  44417. + throw new IllegalArgumentException(
  44418. + "Extract ByteBuffer from an MlImage created by objects other than Bytebuffer is not"
  44419. + + " supported");
  44420. + }
  44421. + }
  44422. /**
  44423. - * Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(MlImage)}.
  44424. + * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link MlImage}.
  44425. + *
  44426. + * <p>Notice: Properties of the {@code image} like rotation will not take effects.
  44427. + *
  44428. + * <p>Format conversion spec:
  44429. + *
  44430. + * <ul>
  44431. + * <li>When extracting RGB images to RGBA format, A channel will always set to 255.
  44432. + * <li>When extracting RGBA images to RGB format, A channel will be dropped.
  44433. + * </ul>
  44434. + *
  44435. + * @param image the image to extract buffer from.
  44436. + * @param targetFormat the image format of the result bytebuffer.
  44437. + * @return the readonly {@link ByteBuffer} stored in {@link MlImage}
  44438. + * @throws IllegalArgumentException when the extraction requires unsupported format or data type
  44439. + * conversions.
  44440. */
  44441. - @ImageFormat
  44442. - public abstract int format();
  44443. -
  44444. - static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) {
  44445. - return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
  44446. + static ByteBuffer extract(MlImage image, @ImageFormat int targetFormat) {
  44447. + ImageContainer container;
  44448. + ImageProperties byteBufferProperties =
  44449. + ImageProperties.builder()
  44450. + .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
  44451. + .setImageFormat(targetFormat)
  44452. + .build();
  44453. + if ((container = image.getContainer(byteBufferProperties)) != null) {
  44454. + ByteBufferImageContainer byteBufferImageContainer =
  44455. + (ByteBufferImageContainer) container;
  44456. + return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
  44457. + } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
  44458. + ByteBufferImageContainer byteBufferImageContainer =
  44459. + (ByteBufferImageContainer) container;
  44460. + @ImageFormat
  44461. + int sourceFormat = byteBufferImageContainer.getImageFormat();
  44462. + return convertByteBuffer(
  44463. + byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
  44464. + .asReadOnlyBuffer();
  44465. + } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) {
  44466. + BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
  44467. + ByteBuffer byteBuffer =
  44468. + extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)
  44469. + .asReadOnlyBuffer();
  44470. + image.addContainer(new ByteBufferImageContainer(byteBuffer, targetFormat));
  44471. + return byteBuffer;
  44472. + } else {
  44473. + throw new IllegalArgumentException(
  44474. + "Extracting ByteBuffer from an MlImage created by objects other than Bitmap or"
  44475. + + " Bytebuffer is not supported");
  44476. + }
  44477. }
  44478. - }
  44479. - /**
  44480. - * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link MlImage}.
  44481. - *
  44482. - * <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy.
  44483. - *
  44484. - * <p>Notice: Properties of the {@code image} like rotation will not take effects.
  44485. - *
  44486. - * @return the readonly {@link ByteBuffer} stored in {@link MlImage}
  44487. - * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
  44488. - * given {@code imageFormat}
  44489. - */
  44490. - static Result extractInRecommendedFormat(MlImage image) {
  44491. - ImageContainer container;
  44492. - if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) {
  44493. - Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
  44494. - @ImageFormat int format = adviseImageFormat(bitmap);
  44495. - Result result =
  44496. - Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);
  44497. + /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */
  44498. + @AutoValue
  44499. + abstract static class Result {
  44500. + /**
  44501. + * Gets the {@link ByteBuffer} in the result of {@link
  44502. + * ByteBufferExtractor#extract(MlImage)}.
  44503. + */
  44504. + public abstract ByteBuffer buffer();
  44505. - image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
  44506. - return result;
  44507. - } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
  44508. - ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
  44509. - return Result.create(
  44510. - byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
  44511. - byteBufferImageContainer.getImageFormat());
  44512. - } else {
  44513. - throw new IllegalArgumentException(
  44514. - "Extract ByteBuffer from an MlImage created by objects other than Bitmap or Bytebuffer"
  44515. - + " is not supported");
  44516. + /**
  44517. + * Gets the {@link ImageFormat} in the result of {@link
  44518. + * ByteBufferExtractor#extract(MlImage)}.
  44519. + */
  44520. + @ImageFormat
  44521. + public abstract int format();
  44522. +
  44523. + static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) {
  44524. + return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
  44525. + }
  44526. }
  44527. - }
  44528. - @ImageFormat
  44529. - private static int adviseImageFormat(Bitmap bitmap) {
  44530. - if (bitmap.getConfig() == Config.ARGB_8888) {
  44531. - return MlImage.IMAGE_FORMAT_RGBA;
  44532. - } else {
  44533. - throw new IllegalArgumentException(
  44534. - String.format(
  44535. - "Extracting ByteBuffer from an MlImage created by a Bitmap in config %s is not"
  44536. - + " supported",
  44537. - bitmap.getConfig()));
  44538. + /**
  44539. + * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link MlImage}.
  44540. + *
  44541. + * <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid
  44542. + * copy.
  44543. + *
  44544. + * <p>Notice: Properties of the {@code image} like rotation will not take effects.
  44545. + *
  44546. + * @return the readonly {@link ByteBuffer} stored in {@link MlImage}
  44547. + * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
  44548. + * given {@code imageFormat}
  44549. + */
  44550. + static Result extractInRecommendedFormat(MlImage image) {
  44551. + ImageContainer container;
  44552. + if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) {
  44553. + Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
  44554. + @ImageFormat
  44555. + int format = adviseImageFormat(bitmap);
  44556. + Result result = Result.create(
  44557. + extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);
  44558. +
  44559. + image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
  44560. + return result;
  44561. + } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
  44562. + ByteBufferImageContainer byteBufferImageContainer =
  44563. + (ByteBufferImageContainer) container;
  44564. + return Result.create(byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
  44565. + byteBufferImageContainer.getImageFormat());
  44566. + } else {
  44567. + throw new IllegalArgumentException(
  44568. + "Extract ByteBuffer from an MlImage created by objects other than Bitmap or Bytebuffer"
  44569. + + " is not supported");
  44570. + }
  44571. }
  44572. - }
  44573. - private static ByteBuffer extractByteBufferFromBitmap(
  44574. - Bitmap bitmap, @ImageFormat int imageFormat) {
  44575. - if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
  44576. - throw new IllegalArgumentException(
  44577. - "Extracting ByteBuffer from an MlImage created by a premultiplied Bitmap is not"
  44578. - + " supported");
  44579. + @ImageFormat
  44580. + private static int adviseImageFormat(Bitmap bitmap) {
  44581. + if (bitmap.getConfig() == Config.ARGB_8888) {
  44582. + return MlImage.IMAGE_FORMAT_RGBA;
  44583. + } else {
  44584. + throw new IllegalArgumentException(String.format(
  44585. + "Extracting ByteBuffer from an MlImage created by a Bitmap in config %s is not"
  44586. + + " supported",
  44587. + bitmap.getConfig()));
  44588. + }
  44589. }
  44590. - if (bitmap.getConfig() == Config.ARGB_8888) {
  44591. - if (imageFormat == MlImage.IMAGE_FORMAT_RGBA) {
  44592. - ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
  44593. - bitmap.copyPixelsToBuffer(buffer);
  44594. - buffer.rewind();
  44595. - return buffer;
  44596. - } else if (imageFormat == MlImage.IMAGE_FORMAT_RGB) {
  44597. - // TODO(b/180504869): Try Use RGBA buffer to create RGB buffer which might be faster.
  44598. - int w = bitmap.getWidth();
  44599. - int h = bitmap.getHeight();
  44600. - int[] pixels = new int[w * h];
  44601. - bitmap.getPixels(pixels, 0, w, 0, 0, w, h);
  44602. - ByteBuffer buffer = ByteBuffer.allocateDirect(w * h * 3);
  44603. - buffer.order(ByteOrder.nativeOrder());
  44604. - for (int pixel : pixels) {
  44605. - // getPixels returns Color in ARGB rather than copyPixelsToBuffer which returns RGBA
  44606. - buffer.put((byte) ((pixel >> 16) & 0xff));
  44607. - buffer.put((byte) ((pixel >> 8) & 0xff));
  44608. - buffer.put((byte) (pixel & 0xff));
  44609. +
  44610. + private static ByteBuffer extractByteBufferFromBitmap(
  44611. + Bitmap bitmap, @ImageFormat int imageFormat) {
  44612. + if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
  44613. + throw new IllegalArgumentException(
  44614. + "Extracting ByteBuffer from an MlImage created by a premultiplied Bitmap is not"
  44615. + + " supported");
  44616. }
  44617. - buffer.rewind();
  44618. - return buffer;
  44619. - }
  44620. + if (bitmap.getConfig() == Config.ARGB_8888) {
  44621. + if (imageFormat == MlImage.IMAGE_FORMAT_RGBA) {
  44622. + ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
  44623. + bitmap.copyPixelsToBuffer(buffer);
  44624. + buffer.rewind();
  44625. + return buffer;
  44626. + } else if (imageFormat == MlImage.IMAGE_FORMAT_RGB) {
  44627. + // TODO(b/180504869): Try Use RGBA buffer to create RGB buffer which might be
  44628. + // faster.
  44629. + int w = bitmap.getWidth();
  44630. + int h = bitmap.getHeight();
  44631. + int[] pixels = new int[w * h];
  44632. + bitmap.getPixels(pixels, 0, w, 0, 0, w, h);
  44633. + ByteBuffer buffer = ByteBuffer.allocateDirect(w * h * 3);
  44634. + buffer.order(ByteOrder.nativeOrder());
  44635. + for (int pixel : pixels) {
  44636. + // getPixels returns Color in ARGB rather than copyPixelsToBuffer which returns
  44637. + // RGBA
  44638. + buffer.put((byte) ((pixel >> 16) & 0xff));
  44639. + buffer.put((byte) ((pixel >> 8) & 0xff));
  44640. + buffer.put((byte) (pixel & 0xff));
  44641. + }
  44642. + buffer.rewind();
  44643. + return buffer;
  44644. + }
  44645. + }
  44646. + throw new IllegalArgumentException(String.format(
  44647. + "Extracting ByteBuffer from an MlImage created by Bitmap and convert from %s to format"
  44648. + + " %d is not supported",
  44649. + bitmap.getConfig(), imageFormat));
  44650. }
  44651. - throw new IllegalArgumentException(
  44652. - String.format(
  44653. - "Extracting ByteBuffer from an MlImage created by Bitmap and convert from %s to format"
  44654. - + " %d is not supported",
  44655. - bitmap.getConfig(), imageFormat));
  44656. - }
  44657. - private static ByteBuffer convertByteBuffer(
  44658. - ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) {
  44659. - if (sourceFormat == MlImage.IMAGE_FORMAT_RGB && targetFormat == MlImage.IMAGE_FORMAT_RGBA) {
  44660. - ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
  44661. - // Extend the buffer when the target is longer than the source. Use two cursors and sweep the
  44662. - // array reversely to convert in-place.
  44663. - byte[] array = new byte[target.capacity()];
  44664. - source.get(array, 0, source.capacity());
  44665. - source.rewind();
  44666. - int rgbCursor = source.capacity();
  44667. - int rgbaCursor = target.capacity();
  44668. - while (rgbCursor != rgbaCursor) {
  44669. - array[--rgbaCursor] = (byte) 0xff; // A
  44670. - array[--rgbaCursor] = array[--rgbCursor]; // B
  44671. - array[--rgbaCursor] = array[--rgbCursor]; // G
  44672. - array[--rgbaCursor] = array[--rgbCursor]; // R
  44673. - }
  44674. - target.put(array, 0, target.capacity());
  44675. - target.rewind();
  44676. - return target;
  44677. - } else if (sourceFormat == MlImage.IMAGE_FORMAT_RGBA
  44678. - && targetFormat == MlImage.IMAGE_FORMAT_RGB) {
  44679. - ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
  44680. - // Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the
  44681. - // array to convert in-place.
  44682. - byte[] array = new byte[source.capacity()];
  44683. - source.get(array, 0, source.capacity());
  44684. - source.rewind();
  44685. - int rgbaCursor = 0;
  44686. - int rgbCursor = 0;
  44687. - while (rgbaCursor < array.length) {
  44688. - array[rgbCursor++] = array[rgbaCursor++]; // R
  44689. - array[rgbCursor++] = array[rgbaCursor++]; // G
  44690. - array[rgbCursor++] = array[rgbaCursor++]; // B
  44691. - rgbaCursor++;
  44692. - }
  44693. - target.put(array, 0, target.capacity());
  44694. - target.rewind();
  44695. - return target;
  44696. - } else {
  44697. - throw new IllegalArgumentException(
  44698. - String.format(
  44699. - Locale.ENGLISH,
  44700. - "Convert bytebuffer image format from %d to %d is not supported",
  44701. - sourceFormat,
  44702. - targetFormat));
  44703. + private static ByteBuffer convertByteBuffer(
  44704. + ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) {
  44705. + if (sourceFormat == MlImage.IMAGE_FORMAT_RGB && targetFormat == MlImage.IMAGE_FORMAT_RGBA) {
  44706. + ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
  44707. + // Extend the buffer when the target is longer than the source. Use two cursors and
  44708. + // sweep the array reversely to convert in-place.
  44709. + byte[] array = new byte[target.capacity()];
  44710. + source.get(array, 0, source.capacity());
  44711. + source.rewind();
  44712. + int rgbCursor = source.capacity();
  44713. + int rgbaCursor = target.capacity();
  44714. + while (rgbCursor != rgbaCursor) {
  44715. + array[--rgbaCursor] = (byte) 0xff; // A
  44716. + array[--rgbaCursor] = array[--rgbCursor]; // B
  44717. + array[--rgbaCursor] = array[--rgbCursor]; // G
  44718. + array[--rgbaCursor] = array[--rgbCursor]; // R
  44719. + }
  44720. + target.put(array, 0, target.capacity());
  44721. + target.rewind();
  44722. + return target;
  44723. + } else if (sourceFormat == MlImage.IMAGE_FORMAT_RGBA
  44724. + && targetFormat == MlImage.IMAGE_FORMAT_RGB) {
  44725. + ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
  44726. + // Shrink the buffer when the target is shorter than the source. Use two cursors and
  44727. + // sweep the array to convert in-place.
  44728. + byte[] array = new byte[source.capacity()];
  44729. + source.get(array, 0, source.capacity());
  44730. + source.rewind();
  44731. + int rgbaCursor = 0;
  44732. + int rgbCursor = 0;
  44733. + while (rgbaCursor < array.length) {
  44734. + array[rgbCursor++] = array[rgbaCursor++]; // R
  44735. + array[rgbCursor++] = array[rgbaCursor++]; // G
  44736. + array[rgbCursor++] = array[rgbaCursor++]; // B
  44737. + rgbaCursor++;
  44738. + }
  44739. + target.put(array, 0, target.capacity());
  44740. + target.rewind();
  44741. + return target;
  44742. + } else {
  44743. + throw new IllegalArgumentException(String.format(Locale.ENGLISH,
  44744. + "Convert bytebuffer image format from %d to %d is not supported", sourceFormat,
  44745. + targetFormat));
  44746. + }
  44747. }
  44748. - }
  44749. - // ByteBuffer is not able to be instantiated.
  44750. - private ByteBufferExtractor() {}
  44751. + // ByteBuffer is not able to be instantiated.
  44752. + private ByteBufferExtractor() {}
  44753. }
  44754. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java
  44755. index 9fbc3cbb94994..f872db485a8a2 100644
  44756. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java
  44757. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java
  44758. @@ -16,42 +16,40 @@ limitations under the License.
  44759. package com.google.android.odml.image;
  44760. import com.google.android.odml.image.MlImage.ImageFormat;
  44761. +
  44762. import java.nio.ByteBuffer;
  44763. class ByteBufferImageContainer implements ImageContainer {
  44764. -
  44765. - private final ByteBuffer buffer;
  44766. - private final ImageProperties properties;
  44767. -
  44768. - public ByteBufferImageContainer(
  44769. - ByteBuffer buffer,
  44770. - @ImageFormat int imageFormat) {
  44771. - this.buffer = buffer;
  44772. - this.properties = ImageProperties.builder()
  44773. - .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
  44774. - .setImageFormat(imageFormat)
  44775. - .build();
  44776. - }
  44777. -
  44778. - public ByteBuffer getByteBuffer() {
  44779. - return buffer;
  44780. - }
  44781. -
  44782. - @Override
  44783. - public ImageProperties getImageProperties() {
  44784. - return properties;
  44785. - }
  44786. -
  44787. - /**
  44788. - * Returns the image format.
  44789. - */
  44790. - @ImageFormat
  44791. - public int getImageFormat() {
  44792. - return properties.getImageFormat();
  44793. - }
  44794. -
  44795. - @Override
  44796. - public void close() {
  44797. - // No op for ByteBuffer.
  44798. - }
  44799. + private final ByteBuffer buffer;
  44800. + private final ImageProperties properties;
  44801. +
  44802. + public ByteBufferImageContainer(ByteBuffer buffer, @ImageFormat int imageFormat) {
  44803. + this.buffer = buffer;
  44804. + this.properties = ImageProperties.builder()
  44805. + .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
  44806. + .setImageFormat(imageFormat)
  44807. + .build();
  44808. + }
  44809. +
  44810. + public ByteBuffer getByteBuffer() {
  44811. + return buffer;
  44812. + }
  44813. +
  44814. + @Override
  44815. + public ImageProperties getImageProperties() {
  44816. + return properties;
  44817. + }
  44818. +
  44819. + /**
  44820. + * Returns the image format.
  44821. + */
  44822. + @ImageFormat
  44823. + public int getImageFormat() {
  44824. + return properties.getImageFormat();
  44825. + }
  44826. +
  44827. + @Override
  44828. + public void close() {
  44829. + // No op for ByteBuffer.
  44830. + }
  44831. }
  44832. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java
  44833. index 421e2b8f0de31..f4b0b31dd5e3b 100644
  44834. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java
  44835. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java
  44836. @@ -16,7 +16,9 @@ limitations under the License.
  44837. package com.google.android.odml.image;
  44838. import android.graphics.Rect;
  44839. +
  44840. import com.google.android.odml.image.MlImage.ImageFormat;
  44841. +
  44842. import java.nio.ByteBuffer;
  44843. /**
  44844. @@ -28,79 +30,74 @@ import java.nio.ByteBuffer;
  44845. * <p>Use {@link ByteBufferExtractor} to get {@link ByteBuffer} you passed in.
  44846. */
  44847. public class ByteBufferMlImageBuilder {
  44848. + // Mandatory fields.
  44849. + private final ByteBuffer buffer;
  44850. + private final int width;
  44851. + private final int height;
  44852. + @ImageFormat
  44853. + private final int imageFormat;
  44854. - // Mandatory fields.
  44855. - private final ByteBuffer buffer;
  44856. - private final int width;
  44857. - private final int height;
  44858. - @ImageFormat private final int imageFormat;
  44859. -
  44860. - // Optional fields.
  44861. - private int rotation;
  44862. - private Rect roi;
  44863. - private long timestamp;
  44864. + // Optional fields.
  44865. + private int rotation;
  44866. + private Rect roi;
  44867. + private long timestamp;
  44868. - /**
  44869. - * Creates the builder with mandatory {@link ByteBuffer} and the represented image.
  44870. - *
  44871. - * <p>We will validate the size of the {@code byteBuffer} with given {@code width}, {@code height}
  44872. - * and {@code imageFormat}.
  44873. - *
  44874. - * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the values
  44875. - * will be set with default:
  44876. - *
  44877. - * <ul>
  44878. - * <li>rotation: 0
  44879. - * </ul>
  44880. - *
  44881. - * @param byteBuffer image data object.
  44882. - * @param width the width of the represented image.
  44883. - * @param height the height of the represented image.
  44884. - * @param imageFormat how the data encode the image.
  44885. - */
  44886. - public ByteBufferMlImageBuilder(
  44887. - ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) {
  44888. - this.buffer = byteBuffer;
  44889. - this.width = width;
  44890. - this.height = height;
  44891. - this.imageFormat = imageFormat;
  44892. - // TODO(b/180504869): Validate bytebuffer size with width, height and image format
  44893. - this.rotation = 0;
  44894. - this.roi = new Rect(0, 0, width, height);
  44895. - this.timestamp = 0;
  44896. - }
  44897. + /**
  44898. + * Creates the builder with mandatory {@link ByteBuffer} and the represented image.
  44899. + *
  44900. + * <p>We will validate the size of the {@code byteBuffer} with given {@code width}, {@code
  44901. + * height} and {@code imageFormat}.
  44902. + *
  44903. + * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the
  44904. + * values will be set with default:
  44905. + *
  44906. + * <ul>
  44907. + * <li>rotation: 0
  44908. + * </ul>
  44909. + *
  44910. + * @param byteBuffer image data object.
  44911. + * @param width the width of the represented image.
  44912. + * @param height the height of the represented image.
  44913. + * @param imageFormat how the data encode the image.
  44914. + */
  44915. + public ByteBufferMlImageBuilder(
  44916. + ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) {
  44917. + this.buffer = byteBuffer;
  44918. + this.width = width;
  44919. + this.height = height;
  44920. + this.imageFormat = imageFormat;
  44921. + // TODO(b/180504869): Validate bytebuffer size with width, height and image format
  44922. + this.rotation = 0;
  44923. + this.roi = new Rect(0, 0, width, height);
  44924. + this.timestamp = 0;
  44925. + }
  44926. - /**
  44927. - * Sets value for {@link MlImage#getRotation()}.
  44928. - *
  44929. - * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
  44930. - */
  44931. - public ByteBufferMlImageBuilder setRotation(int rotation) {
  44932. - MlImage.validateRotation(rotation);
  44933. - this.rotation = rotation;
  44934. - return this;
  44935. - }
  44936. + /**
  44937. + * Sets value for {@link MlImage#getRotation()}.
  44938. + *
  44939. + * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
  44940. + */
  44941. + public ByteBufferMlImageBuilder setRotation(int rotation) {
  44942. + MlImage.validateRotation(rotation);
  44943. + this.rotation = rotation;
  44944. + return this;
  44945. + }
  44946. - /** Sets value for {@link MlImage#getRoi()}. */
  44947. - ByteBufferMlImageBuilder setRoi(Rect roi) {
  44948. - this.roi = roi;
  44949. - return this;
  44950. - }
  44951. + /** Sets value for {@link MlImage#getRoi()}. */
  44952. + ByteBufferMlImageBuilder setRoi(Rect roi) {
  44953. + this.roi = roi;
  44954. + return this;
  44955. + }
  44956. - /** Sets value for {@link MlImage#getTimestamp()}. */
  44957. - ByteBufferMlImageBuilder setTimestamp(long timestamp) {
  44958. - this.timestamp = timestamp;
  44959. - return this;
  44960. - }
  44961. + /** Sets value for {@link MlImage#getTimestamp()}. */
  44962. + ByteBufferMlImageBuilder setTimestamp(long timestamp) {
  44963. + this.timestamp = timestamp;
  44964. + return this;
  44965. + }
  44966. - /** Builds an {@link MlImage} instance. */
  44967. - public MlImage build() {
  44968. - return new MlImage(
  44969. - new ByteBufferImageContainer(buffer, imageFormat),
  44970. - rotation,
  44971. - roi,
  44972. - timestamp,
  44973. - width,
  44974. - height);
  44975. - }
  44976. + /** Builds an {@link MlImage} instance. */
  44977. + public MlImage build() {
  44978. + return new MlImage(new ByteBufferImageContainer(buffer, imageFormat), rotation, roi,
  44979. + timestamp, width, height);
  44980. + }
  44981. }
  44982. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java
  44983. index 25ed2312ce580..bfa7c0a292f4f 100644
  44984. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java
  44985. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java
  44986. @@ -20,11 +20,11 @@ import com.google.android.odml.image.annotation.KeepForSdk;
  44987. /** Manages internal image data storage. The interface is package-private. */
  44988. @KeepForSdk
  44989. interface ImageContainer {
  44990. - /** Returns the properties of the contained image. */
  44991. - @KeepForSdk
  44992. - ImageProperties getImageProperties();
  44993. + /** Returns the properties of the contained image. */
  44994. + @KeepForSdk
  44995. + ImageProperties getImageProperties();
  44996. - /** Close the image container and releases the image resource inside. */
  44997. - @KeepForSdk
  44998. - void close();
  44999. + /** Close the image container and releases the image resource inside. */
  45000. + @KeepForSdk
  45001. + void close();
  45002. }
  45003. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java
  45004. index 717bc5f9935ed..a61e97b81b872 100644
  45005. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java
  45006. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java
  45007. @@ -24,63 +24,61 @@ import com.google.auto.value.extension.memoized.Memoized;
  45008. /** Groups a set of properties to describe how an image is stored. */
  45009. @AutoValue
  45010. public abstract class ImageProperties {
  45011. -
  45012. - /**
  45013. - * Gets the pixel format of the image.
  45014. - *
  45015. - * @see MlImage.ImageFormat
  45016. - */
  45017. - @ImageFormat
  45018. - public abstract int getImageFormat();
  45019. -
  45020. - /**
  45021. - * Gets the storage type of the image.
  45022. - *
  45023. - * @see MlImage.StorageType
  45024. - */
  45025. - @StorageType
  45026. - public abstract int getStorageType();
  45027. -
  45028. - @Memoized
  45029. - @Override
  45030. - public abstract int hashCode();
  45031. -
  45032. - /**
  45033. - * Creates a builder of {@link ImageProperties}.
  45034. - *
  45035. - * @see ImageProperties.Builder
  45036. - */
  45037. - @KeepForSdk
  45038. - static Builder builder() {
  45039. - return new AutoValue_ImageProperties.Builder();
  45040. - }
  45041. -
  45042. - /** Builds a {@link ImageProperties}. */
  45043. - @AutoValue.Builder
  45044. - @KeepForSdk
  45045. - abstract static class Builder {
  45046. + /**
  45047. + * Gets the pixel format of the image.
  45048. + *
  45049. + * @see MlImage.ImageFormat
  45050. + */
  45051. + @ImageFormat
  45052. + public abstract int getImageFormat();
  45053. /**
  45054. - * Sets the {@link MlImage.ImageFormat}.
  45055. + * Gets the storage type of the image.
  45056. *
  45057. - * @see ImageProperties#getImageFormat
  45058. + * @see MlImage.StorageType
  45059. */
  45060. - @KeepForSdk
  45061. - abstract Builder setImageFormat(@ImageFormat int value);
  45062. + @StorageType
  45063. + public abstract int getStorageType();
  45064. +
  45065. + @Memoized
  45066. + @Override
  45067. + public abstract int hashCode();
  45068. /**
  45069. - * Sets the {@link MlImage.StorageType}.
  45070. + * Creates a builder of {@link ImageProperties}.
  45071. *
  45072. - * @see ImageProperties#getStorageType
  45073. + * @see ImageProperties.Builder
  45074. */
  45075. @KeepForSdk
  45076. - abstract Builder setStorageType(@StorageType int value);
  45077. + static Builder builder() {
  45078. + return new AutoValue_ImageProperties.Builder();
  45079. + }
  45080. - /** Builds the {@link ImageProperties}. */
  45081. + /** Builds a {@link ImageProperties}. */
  45082. + @AutoValue.Builder
  45083. @KeepForSdk
  45084. - abstract ImageProperties build();
  45085. - }
  45086. + abstract static class Builder {
  45087. + /**
  45088. + * Sets the {@link MlImage.ImageFormat}.
  45089. + *
  45090. + * @see ImageProperties#getImageFormat
  45091. + */
  45092. + @KeepForSdk
  45093. + abstract Builder setImageFormat(@ImageFormat int value);
  45094. +
  45095. + /**
  45096. + * Sets the {@link MlImage.StorageType}.
  45097. + *
  45098. + * @see ImageProperties#getStorageType
  45099. + */
  45100. + @KeepForSdk
  45101. + abstract Builder setStorageType(@StorageType int value);
  45102. +
  45103. + /** Builds the {@link ImageProperties}. */
  45104. + @KeepForSdk
  45105. + abstract ImageProperties build();
  45106. + }
  45107. - // Hide the constructor.
  45108. - ImageProperties() {}
  45109. + // Hide the constructor.
  45110. + ImageProperties() {}
  45111. }
  45112. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java
  45113. index 9365d0b2a422e..9ed88ee30c62f 100644
  45114. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java
  45115. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java
  45116. @@ -19,55 +19,56 @@ import android.media.Image;
  45117. import android.os.Build;
  45118. import android.os.Build.VERSION;
  45119. import android.os.Build.VERSION_CODES;
  45120. +
  45121. import androidx.annotation.RequiresApi;
  45122. +
  45123. import com.google.android.odml.image.MlImage.ImageFormat;
  45124. @RequiresApi(VERSION_CODES.KITKAT)
  45125. class MediaImageContainer implements ImageContainer {
  45126. + private final Image mediaImage;
  45127. + private final ImageProperties properties;
  45128. - private final Image mediaImage;
  45129. - private final ImageProperties properties;
  45130. -
  45131. - public MediaImageContainer(Image mediaImage) {
  45132. - this.mediaImage = mediaImage;
  45133. - this.properties = ImageProperties.builder()
  45134. - .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE)
  45135. - .setImageFormat(convertFormatCode(mediaImage.getFormat()))
  45136. - .build();
  45137. - }
  45138. -
  45139. - public Image getImage() {
  45140. - return mediaImage;
  45141. - }
  45142. + public MediaImageContainer(Image mediaImage) {
  45143. + this.mediaImage = mediaImage;
  45144. + this.properties = ImageProperties.builder()
  45145. + .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE)
  45146. + .setImageFormat(convertFormatCode(mediaImage.getFormat()))
  45147. + .build();
  45148. + }
  45149. - @Override
  45150. - public ImageProperties getImageProperties() {
  45151. - return properties;
  45152. - }
  45153. + public Image getImage() {
  45154. + return mediaImage;
  45155. + }
  45156. - @Override
  45157. - public void close() {
  45158. - mediaImage.close();
  45159. - }
  45160. + @Override
  45161. + public ImageProperties getImageProperties() {
  45162. + return properties;
  45163. + }
  45164. - @ImageFormat
  45165. - static int convertFormatCode(int graphicsFormat) {
  45166. - // We only cover the format mentioned in
  45167. - // https://developer.android.com/reference/android/media/Image#getFormat()
  45168. - if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
  45169. - if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
  45170. - return MlImage.IMAGE_FORMAT_RGBA;
  45171. - } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
  45172. - return MlImage.IMAGE_FORMAT_RGB;
  45173. - }
  45174. + @Override
  45175. + public void close() {
  45176. + mediaImage.close();
  45177. }
  45178. - switch (graphicsFormat) {
  45179. - case android.graphics.ImageFormat.JPEG:
  45180. - return MlImage.IMAGE_FORMAT_JPEG;
  45181. - case android.graphics.ImageFormat.YUV_420_888:
  45182. - return MlImage.IMAGE_FORMAT_YUV_420_888;
  45183. - default:
  45184. - return MlImage.IMAGE_FORMAT_UNKNOWN;
  45185. +
  45186. + @ImageFormat
  45187. + static int convertFormatCode(int graphicsFormat) {
  45188. + // We only cover the format mentioned in
  45189. + // https://developer.android.com/reference/android/media/Image#getFormat()
  45190. + if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
  45191. + if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
  45192. + return MlImage.IMAGE_FORMAT_RGBA;
  45193. + } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
  45194. + return MlImage.IMAGE_FORMAT_RGB;
  45195. + }
  45196. + }
  45197. + switch (graphicsFormat) {
  45198. + case android.graphics.ImageFormat.JPEG:
  45199. + return MlImage.IMAGE_FORMAT_JPEG;
  45200. + case android.graphics.ImageFormat.YUV_420_888:
  45201. + return MlImage.IMAGE_FORMAT_YUV_420_888;
  45202. + default:
  45203. + return MlImage.IMAGE_FORMAT_UNKNOWN;
  45204. + }
  45205. }
  45206. - }
  45207. }
  45208. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java
  45209. index 73aadabb38789..59ed98b569fa2 100644
  45210. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java
  45211. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java
  45212. @@ -17,6 +17,7 @@ package com.google.android.odml.image;
  45213. import android.media.Image;
  45214. import android.os.Build.VERSION_CODES;
  45215. +
  45216. import androidx.annotation.RequiresApi;
  45217. /**
  45218. @@ -27,26 +28,25 @@ import androidx.annotation.RequiresApi;
  45219. */
  45220. @RequiresApi(VERSION_CODES.KITKAT)
  45221. public class MediaImageExtractor {
  45222. -
  45223. - private MediaImageExtractor() {}
  45224. -
  45225. - /**
  45226. - * Extracts a {@link android.media.Image} from an {@link MlImage}. Currently it only works for
  45227. - * {@link MlImage} that built from {@link MediaMlImageBuilder}.
  45228. - *
  45229. - * <p>Notice: Properties of the {@code image} like rotation will not take effects.
  45230. - *
  45231. - * @param image the image to extract {@link android.media.Image} from.
  45232. - * @return {@link android.media.Image} that stored in {@link MlImage}.
  45233. - * @throws IllegalArgumentException if the extraction failed.
  45234. - */
  45235. - public static Image extract(MlImage image) {
  45236. - ImageContainer container;
  45237. - if ((container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
  45238. - return ((MediaImageContainer) container).getImage();
  45239. + private MediaImageExtractor() {}
  45240. +
  45241. + /**
  45242. + * Extracts a {@link android.media.Image} from an {@link MlImage}. Currently it only works for
  45243. + * {@link MlImage} that built from {@link MediaMlImageBuilder}.
  45244. + *
  45245. + * <p>Notice: Properties of the {@code image} like rotation will not take effects.
  45246. + *
  45247. + * @param image the image to extract {@link android.media.Image} from.
  45248. + * @return {@link android.media.Image} that stored in {@link MlImage}.
  45249. + * @throws IllegalArgumentException if the extraction failed.
  45250. + */
  45251. + public static Image extract(MlImage image) {
  45252. + ImageContainer container;
  45253. + if ((container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
  45254. + return ((MediaImageContainer) container).getImage();
  45255. + }
  45256. + throw new IllegalArgumentException(
  45257. + "Extract Media Image from an MlImage created by objects other than Media Image"
  45258. + + " is not supported");
  45259. }
  45260. - throw new IllegalArgumentException(
  45261. - "Extract Media Image from an MlImage created by objects other than Media Image"
  45262. - + " is not supported");
  45263. - }
  45264. }
  45265. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java
  45266. index e96ab38317bac..80771bdb91890 100644
  45267. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java
  45268. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java
  45269. @@ -18,6 +18,7 @@ package com.google.android.odml.image;
  45270. import android.graphics.Rect;
  45271. import android.media.Image;
  45272. import android.os.Build.VERSION_CODES;
  45273. +
  45274. import androidx.annotation.RequiresApi;
  45275. /**
  45276. @@ -30,65 +31,59 @@ import androidx.annotation.RequiresApi;
  45277. */
  45278. @RequiresApi(VERSION_CODES.KITKAT)
  45279. public class MediaMlImageBuilder {
  45280. + // Mandatory fields.
  45281. + private final Image mediaImage;
  45282. - // Mandatory fields.
  45283. - private final Image mediaImage;
  45284. -
  45285. - // Optional fields.
  45286. - private int rotation;
  45287. - private Rect roi;
  45288. - private long timestamp;
  45289. + // Optional fields.
  45290. + private int rotation;
  45291. + private Rect roi;
  45292. + private long timestamp;
  45293. - /**
  45294. - * Creates the builder with a mandatory {@link android.media.Image}.
  45295. - *
  45296. - * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the values
  45297. - * will be set with default:
  45298. - *
  45299. - * <ul>
  45300. - * <li>rotation: 0
  45301. - * </ul>
  45302. - *
  45303. - * @param mediaImage image data object.
  45304. - */
  45305. - public MediaMlImageBuilder(Image mediaImage) {
  45306. - this.mediaImage = mediaImage;
  45307. - this.rotation = 0;
  45308. - this.roi = new Rect(0, 0, mediaImage.getWidth(), mediaImage.getHeight());
  45309. - this.timestamp = 0;
  45310. - }
  45311. + /**
  45312. + * Creates the builder with a mandatory {@link android.media.Image}.
  45313. + *
  45314. + * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the
  45315. + * values will be set with default:
  45316. + *
  45317. + * <ul>
  45318. + * <li>rotation: 0
  45319. + * </ul>
  45320. + *
  45321. + * @param mediaImage image data object.
  45322. + */
  45323. + public MediaMlImageBuilder(Image mediaImage) {
  45324. + this.mediaImage = mediaImage;
  45325. + this.rotation = 0;
  45326. + this.roi = new Rect(0, 0, mediaImage.getWidth(), mediaImage.getHeight());
  45327. + this.timestamp = 0;
  45328. + }
  45329. - /**
  45330. - * Sets value for {@link MlImage#getRotation()}.
  45331. - *
  45332. - * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
  45333. - */
  45334. - public MediaMlImageBuilder setRotation(int rotation) {
  45335. - MlImage.validateRotation(rotation);
  45336. - this.rotation = rotation;
  45337. - return this;
  45338. - }
  45339. + /**
  45340. + * Sets value for {@link MlImage#getRotation()}.
  45341. + *
  45342. + * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
  45343. + */
  45344. + public MediaMlImageBuilder setRotation(int rotation) {
  45345. + MlImage.validateRotation(rotation);
  45346. + this.rotation = rotation;
  45347. + return this;
  45348. + }
  45349. - /** Sets value for {@link MlImage#getRoi()}. */
  45350. - MediaMlImageBuilder setRoi(Rect roi) {
  45351. - this.roi = roi;
  45352. - return this;
  45353. - }
  45354. + /** Sets value for {@link MlImage#getRoi()}. */
  45355. + MediaMlImageBuilder setRoi(Rect roi) {
  45356. + this.roi = roi;
  45357. + return this;
  45358. + }
  45359. - /** Sets value for {@link MlImage#getTimestamp()}. */
  45360. - MediaMlImageBuilder setTimestamp(long timestamp) {
  45361. - this.timestamp = timestamp;
  45362. - return this;
  45363. - }
  45364. + /** Sets value for {@link MlImage#getTimestamp()}. */
  45365. + MediaMlImageBuilder setTimestamp(long timestamp) {
  45366. + this.timestamp = timestamp;
  45367. + return this;
  45368. + }
  45369. - /** Builds an {@link MlImage} instance. */
  45370. - public MlImage build() {
  45371. - return new MlImage(
  45372. - new MediaImageContainer(mediaImage),
  45373. - rotation,
  45374. - roi,
  45375. - timestamp,
  45376. - mediaImage.getWidth(),
  45377. - mediaImage.getHeight());
  45378. - }
  45379. + /** Builds an {@link MlImage} instance. */
  45380. + public MlImage build() {
  45381. + return new MlImage(new MediaImageContainer(mediaImage), rotation, roi, timestamp,
  45382. + mediaImage.getWidth(), mediaImage.getHeight());
  45383. + }
  45384. }
  45385. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java
  45386. index 2ed3539de67f5..7e21e6ad428f2 100644
  45387. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java
  45388. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java
  45389. @@ -16,9 +16,12 @@ limitations under the License.
  45390. package com.google.android.odml.image;
  45391. import android.graphics.Rect;
  45392. +
  45393. import androidx.annotation.IntDef;
  45394. import androidx.annotation.Nullable;
  45395. +
  45396. import com.google.android.odml.image.annotation.KeepForSdk;
  45397. +
  45398. import java.io.Closeable;
  45399. import java.lang.annotation.Retention;
  45400. import java.lang.annotation.RetentionPolicy;
  45401. @@ -62,228 +65,232 @@ import java.util.Map.Entry;
  45402. * and multiple storages.
  45403. */
  45404. public class MlImage implements Closeable {
  45405. + /** Specifies the image format of an image. */
  45406. + @IntDef({
  45407. + IMAGE_FORMAT_UNKNOWN,
  45408. + IMAGE_FORMAT_RGBA,
  45409. + IMAGE_FORMAT_RGB,
  45410. + IMAGE_FORMAT_NV12,
  45411. + IMAGE_FORMAT_NV21,
  45412. + IMAGE_FORMAT_YV12,
  45413. + IMAGE_FORMAT_YV21,
  45414. + IMAGE_FORMAT_YUV_420_888,
  45415. + IMAGE_FORMAT_ALPHA,
  45416. + IMAGE_FORMAT_JPEG,
  45417. + })
  45418. + @Retention(RetentionPolicy.SOURCE)
  45419. + public @interface ImageFormat {}
  45420. +
  45421. + public static final int IMAGE_FORMAT_UNKNOWN = 0;
  45422. + public static final int IMAGE_FORMAT_RGBA = 1;
  45423. + public static final int IMAGE_FORMAT_RGB = 2;
  45424. + public static final int IMAGE_FORMAT_NV12 = 3;
  45425. + public static final int IMAGE_FORMAT_NV21 = 4;
  45426. + public static final int IMAGE_FORMAT_YV12 = 5;
  45427. + public static final int IMAGE_FORMAT_YV21 = 6;
  45428. + public static final int IMAGE_FORMAT_YUV_420_888 = 7;
  45429. + public static final int IMAGE_FORMAT_ALPHA = 8;
  45430. + public static final int IMAGE_FORMAT_JPEG = 9;
  45431. +
  45432. + /** Specifies the image container type. Would be useful for choosing extractors. */
  45433. + @IntDef({
  45434. + STORAGE_TYPE_BITMAP,
  45435. + STORAGE_TYPE_BYTEBUFFER,
  45436. + STORAGE_TYPE_MEDIA_IMAGE,
  45437. + STORAGE_TYPE_IMAGE_PROXY,
  45438. + })
  45439. + @Retention(RetentionPolicy.SOURCE)
  45440. + public @interface StorageType {}
  45441. +
  45442. + public static final int STORAGE_TYPE_BITMAP = 1;
  45443. + public static final int STORAGE_TYPE_BYTEBUFFER = 2;
  45444. + public static final int STORAGE_TYPE_MEDIA_IMAGE = 3;
  45445. + public static final int STORAGE_TYPE_IMAGE_PROXY = 4;
  45446. +
  45447. + /**
  45448. + * Returns a list of supported image properties for this {@link MlImage}.
  45449. + *
  45450. + * <p>Currently {@link MlImage} only support single storage type so the size of return list will
  45451. + * always be 1.
  45452. + *
  45453. + * @see ImageProperties
  45454. + */
  45455. + public List<ImageProperties> getContainedImageProperties() {
  45456. + return Collections.singletonList(getContainer().getImageProperties());
  45457. + }
  45458. +
  45459. + /** Returns the rotation value attached to the image. Rotation value will be 0, 90, 180, 270. */
  45460. + public int getRotation() {
  45461. + return rotation;
  45462. + }
  45463. +
  45464. + /** Returns the timestamp attached to the image. */
  45465. + long getTimestamp() {
  45466. + return timestamp;
  45467. + }
  45468. +
  45469. + /** Returns the width of the image. */
  45470. + public int getWidth() {
  45471. + return width;
  45472. + }
  45473. +
  45474. + /** Returns the height of the image. */
  45475. + public int getHeight() {
  45476. + return height;
  45477. + }
  45478. - /** Specifies the image format of an image. */
  45479. - @IntDef({
  45480. - IMAGE_FORMAT_UNKNOWN,
  45481. - IMAGE_FORMAT_RGBA,
  45482. - IMAGE_FORMAT_RGB,
  45483. - IMAGE_FORMAT_NV12,
  45484. - IMAGE_FORMAT_NV21,
  45485. - IMAGE_FORMAT_YV12,
  45486. - IMAGE_FORMAT_YV21,
  45487. - IMAGE_FORMAT_YUV_420_888,
  45488. - IMAGE_FORMAT_ALPHA,
  45489. - IMAGE_FORMAT_JPEG,
  45490. - })
  45491. - @Retention(RetentionPolicy.SOURCE)
  45492. - public @interface ImageFormat {}
  45493. -
  45494. - public static final int IMAGE_FORMAT_UNKNOWN = 0;
  45495. - public static final int IMAGE_FORMAT_RGBA = 1;
  45496. - public static final int IMAGE_FORMAT_RGB = 2;
  45497. - public static final int IMAGE_FORMAT_NV12 = 3;
  45498. - public static final int IMAGE_FORMAT_NV21 = 4;
  45499. - public static final int IMAGE_FORMAT_YV12 = 5;
  45500. - public static final int IMAGE_FORMAT_YV21 = 6;
  45501. - public static final int IMAGE_FORMAT_YUV_420_888 = 7;
  45502. - public static final int IMAGE_FORMAT_ALPHA = 8;
  45503. - public static final int IMAGE_FORMAT_JPEG = 9;
  45504. -
  45505. - /** Specifies the image container type. Would be useful for choosing extractors. */
  45506. - @IntDef({
  45507. - STORAGE_TYPE_BITMAP,
  45508. - STORAGE_TYPE_BYTEBUFFER,
  45509. - STORAGE_TYPE_MEDIA_IMAGE,
  45510. - STORAGE_TYPE_IMAGE_PROXY,
  45511. - })
  45512. - @Retention(RetentionPolicy.SOURCE)
  45513. - public @interface StorageType {}
  45514. -
  45515. - public static final int STORAGE_TYPE_BITMAP = 1;
  45516. - public static final int STORAGE_TYPE_BYTEBUFFER = 2;
  45517. - public static final int STORAGE_TYPE_MEDIA_IMAGE = 3;
  45518. - public static final int STORAGE_TYPE_IMAGE_PROXY = 4;
  45519. -
  45520. - /**
  45521. - * Returns a list of supported image properties for this {@link MlImage}.
  45522. - *
  45523. - * <p>Currently {@link MlImage} only support single storage type so the size of return list will
  45524. - * always be 1.
  45525. - *
  45526. - * @see ImageProperties
  45527. - */
  45528. - public List<ImageProperties> getContainedImageProperties() {
  45529. - return Collections.singletonList(getContainer().getImageProperties());
  45530. - }
  45531. -
  45532. - /** Returns the rotation value attached to the image. Rotation value will be 0, 90, 180, 270. */
  45533. - public int getRotation() {
  45534. - return rotation;
  45535. - }
  45536. -
  45537. - /** Returns the timestamp attached to the image. */
  45538. - long getTimestamp() {
  45539. - return timestamp;
  45540. - }
  45541. -
  45542. - /** Returns the width of the image. */
  45543. - public int getWidth() {
  45544. - return width;
  45545. - }
  45546. -
  45547. - /** Returns the height of the image. */
  45548. - public int getHeight() {
  45549. - return height;
  45550. - }
  45551. -
  45552. - /** Returns the region-of-interest rectangle attached to the image. */
  45553. - Rect getRoi() {
  45554. - Rect result = new Rect();
  45555. - result.set(roi);
  45556. - return result;
  45557. - }
  45558. -
  45559. - /** Acquires a reference on this {@link MlImage}. This will increase the reference count by 1. */
  45560. - private synchronized void acquire() {
  45561. - referenceCount += 1;
  45562. - }
  45563. -
  45564. - /**
  45565. - * Removes a reference that was previously acquired or init.
  45566. - *
  45567. - * <p>When {@link MlImage} is created, it has 1 reference count.
  45568. - *
  45569. - * <p>When the reference count becomes 0, it will release the resource under the hood.
  45570. - */
  45571. - @Override
  45572. - // TODO(b/189767728): Create an internal flag to indicate image is closed, or use referenceCount
  45573. - public synchronized void close() {
  45574. - referenceCount -= 1;
  45575. - if (referenceCount == 0) {
  45576. - for (ImageContainer imageContainer : containerMap.values()) {
  45577. - imageContainer.close();
  45578. - }
  45579. + /** Returns the region-of-interest rectangle attached to the image. */
  45580. + Rect getRoi() {
  45581. + Rect result = new Rect();
  45582. + result.set(roi);
  45583. + return result;
  45584. }
  45585. - }
  45586. -
  45587. - /**
  45588. - * Advanced API access for {@link MlImage}.
  45589. - *
  45590. - * <p>These APIs are useful for other infrastructures, for example, acquiring extra reference
  45591. - * count for {@link MlImage}. However, an App developer should avoid using the following APIs.
  45592. - *
  45593. - * <p>APIs inside are treated as internal APIs which are subject to change.
  45594. - */
  45595. - public static final class Internal {
  45596. /**
  45597. * Acquires a reference on this {@link MlImage}. This will increase the reference count by 1.
  45598. + */
  45599. + private synchronized void acquire() {
  45600. + referenceCount += 1;
  45601. + }
  45602. +
  45603. + /**
  45604. + * Removes a reference that was previously acquired or init.
  45605. + *
  45606. + * <p>When {@link MlImage} is created, it has 1 reference count.
  45607. *
  45608. - * <p>This method is more useful for image consumer to acquire a reference so image resource
  45609. - * will not be closed accidentally. As image creator, normal developer doesn't need to call this
  45610. - * method.
  45611. + * <p>When the reference count becomes 0, it will release the resource under the hood.
  45612. + */
  45613. + @Override
  45614. + // TODO(b/189767728): Create an internal flag to indicate image is closed, or use referenceCount
  45615. + public synchronized void close() {
  45616. + referenceCount -= 1;
  45617. + if (referenceCount == 0) {
  45618. + for (ImageContainer imageContainer : containerMap.values()) {
  45619. + imageContainer.close();
  45620. + }
  45621. + }
  45622. + }
  45623. +
  45624. + /**
  45625. + * Advanced API access for {@link MlImage}.
  45626. *
  45627. - * <p>The reference count is 1 when {@link MlImage} is created. Developer can call {@link
  45628. - * #close()} to indicate it doesn't need this {@link MlImage} anymore.
  45629. + * <p>These APIs are useful for other infrastructures, for example, acquiring extra reference
  45630. + * count for {@link MlImage}. However, an App developer should avoid using the following APIs.
  45631. *
  45632. - * @see #close()
  45633. + * <p>APIs inside are treated as internal APIs which are subject to change.
  45634. */
  45635. - public void acquire() {
  45636. - image.acquire();
  45637. + public static final class Internal {
  45638. + /**
  45639. + * Acquires a reference on this {@link MlImage}. This will increase the reference count
  45640. + * by 1.
  45641. + *
  45642. + * <p>This method is more useful for image consumer to acquire a reference so image resource
  45643. + * will not be closed accidentally. As image creator, normal developer doesn't need to call
  45644. + * this method.
  45645. + *
  45646. + * <p>The reference count is 1 when {@link MlImage} is created. Developer can call {@link
  45647. + * #close()} to indicate it doesn't need this {@link MlImage} anymore.
  45648. + *
  45649. + * @see #close()
  45650. + */
  45651. + public void acquire() {
  45652. + image.acquire();
  45653. + }
  45654. +
  45655. + private final MlImage image;
  45656. +
  45657. + // Only MlImage creates the internal helper.
  45658. + private Internal(MlImage image) {
  45659. + this.image = image;
  45660. + }
  45661. + }
  45662. +
  45663. + /** Gets {@link Internal} object which contains internal APIs. */
  45664. + public Internal getInternal() {
  45665. + return new Internal(this);
  45666. }
  45667. - private final MlImage image;
  45668. + private final Map<ImageProperties, ImageContainer> containerMap;
  45669. + private final int rotation;
  45670. + private final Rect roi;
  45671. + private final long timestamp;
  45672. + private final int width;
  45673. + private final int height;
  45674. +
  45675. + private int referenceCount;
  45676. +
  45677. + /** Constructs an {@link MlImage} with a built container. */
  45678. + @KeepForSdk
  45679. + MlImage(ImageContainer container, int rotation, Rect roi, long timestamp, int width,
  45680. + int height) {
  45681. + this.containerMap = new HashMap<>();
  45682. + containerMap.put(container.getImageProperties(), container);
  45683. + this.rotation = rotation;
  45684. + this.roi = new Rect();
  45685. + this.roi.set(roi);
  45686. + this.timestamp = timestamp;
  45687. + this.width = width;
  45688. + this.height = height;
  45689. + this.referenceCount = 1;
  45690. + }
  45691. +
  45692. + /**
  45693. + * Gets one available container.
  45694. + *
  45695. + * @return the current container.
  45696. + */
  45697. + @KeepForSdk
  45698. + ImageContainer getContainer() {
  45699. + // According to the design, in the future we will support multiple containers in one image.
  45700. + // Currently just return the original container.
  45701. + // TODO(b/182443927): Cache multiple containers in MlImage.
  45702. + return containerMap.values().iterator().next();
  45703. + }
  45704. - // Only MlImage creates the internal helper.
  45705. - private Internal(MlImage image) {
  45706. - this.image = image;
  45707. + /**
  45708. + * Gets container from required {@code storageType}. Returns {@code null} if not existed.
  45709. + *
  45710. + * <p>If there are multiple containers with required {@code storageType}, returns the first one.
  45711. + */
  45712. + @Nullable
  45713. + @KeepForSdk
  45714. + ImageContainer getContainer(@StorageType int storageType) {
  45715. + for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) {
  45716. + if (entry.getKey().getStorageType() == storageType) {
  45717. + return entry.getValue();
  45718. + }
  45719. + }
  45720. + return null;
  45721. }
  45722. - }
  45723. -
  45724. - /** Gets {@link Internal} object which contains internal APIs. */
  45725. - public Internal getInternal() {
  45726. - return new Internal(this);
  45727. - }
  45728. -
  45729. - private final Map<ImageProperties, ImageContainer> containerMap;
  45730. - private final int rotation;
  45731. - private final Rect roi;
  45732. - private final long timestamp;
  45733. - private final int width;
  45734. - private final int height;
  45735. -
  45736. - private int referenceCount;
  45737. -
  45738. - /** Constructs an {@link MlImage} with a built container. */
  45739. - @KeepForSdk
  45740. - MlImage(ImageContainer container, int rotation, Rect roi, long timestamp, int width, int height) {
  45741. - this.containerMap = new HashMap<>();
  45742. - containerMap.put(container.getImageProperties(), container);
  45743. - this.rotation = rotation;
  45744. - this.roi = new Rect();
  45745. - this.roi.set(roi);
  45746. - this.timestamp = timestamp;
  45747. - this.width = width;
  45748. - this.height = height;
  45749. - this.referenceCount = 1;
  45750. - }
  45751. -
  45752. - /**
  45753. - * Gets one available container.
  45754. - *
  45755. - * @return the current container.
  45756. - */
  45757. - @KeepForSdk
  45758. - ImageContainer getContainer() {
  45759. - // According to the design, in the future we will support multiple containers in one image.
  45760. - // Currently just return the original container.
  45761. - // TODO(b/182443927): Cache multiple containers in MlImage.
  45762. - return containerMap.values().iterator().next();
  45763. - }
  45764. -
  45765. - /**
  45766. - * Gets container from required {@code storageType}. Returns {@code null} if not existed.
  45767. - *
  45768. - * <p>If there are multiple containers with required {@code storageType}, returns the first one.
  45769. - */
  45770. - @Nullable
  45771. - @KeepForSdk
  45772. - ImageContainer getContainer(@StorageType int storageType) {
  45773. - for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) {
  45774. - if (entry.getKey().getStorageType() == storageType) {
  45775. - return entry.getValue();
  45776. - }
  45777. +
  45778. + /**
  45779. + * Gets container from required {@code imageProperties}. Returns {@code null} if non existed.
  45780. + */
  45781. + @Nullable
  45782. + @KeepForSdk
  45783. + ImageContainer getContainer(ImageProperties imageProperties) {
  45784. + return containerMap.get(imageProperties);
  45785. }
  45786. - return null;
  45787. - }
  45788. -
  45789. - /** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */
  45790. - @Nullable
  45791. - @KeepForSdk
  45792. - ImageContainer getContainer(ImageProperties imageProperties) {
  45793. - return containerMap.get(imageProperties);
  45794. - }
  45795. -
  45796. - /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
  45797. - boolean addContainer(ImageContainer container) {
  45798. - ImageProperties imageProperties = container.getImageProperties();
  45799. - if (containerMap.containsKey(imageProperties)) {
  45800. - return false;
  45801. +
  45802. + /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
  45803. + boolean addContainer(ImageContainer container) {
  45804. + ImageProperties imageProperties = container.getImageProperties();
  45805. + if (containerMap.containsKey(imageProperties)) {
  45806. + return false;
  45807. + }
  45808. + containerMap.put(imageProperties, container);
  45809. + return true;
  45810. }
  45811. - containerMap.put(imageProperties, container);
  45812. - return true;
  45813. - }
  45814. -
  45815. - /**
  45816. - * Validates rotation values for builders. Only supports 0, 90, 180, 270.
  45817. - *
  45818. - * @throws IllegalArgumentException if the rotation value is invalid.
  45819. - */
  45820. - static void validateRotation(int rotation) {
  45821. - if (rotation != 0 && rotation != 90 && rotation != 180 && rotation != 270) {
  45822. - throw new IllegalArgumentException(
  45823. - "Rotation value " + rotation + " is not valid. Use only 0, 90, 180 or 270.");
  45824. +
  45825. + /**
  45826. + * Validates rotation values for builders. Only supports 0, 90, 180, 270.
  45827. + *
  45828. + * @throws IllegalArgumentException if the rotation value is invalid.
  45829. + */
  45830. + static void validateRotation(int rotation) {
  45831. + if (rotation != 0 && rotation != 90 && rotation != 180 && rotation != 270) {
  45832. + throw new IllegalArgumentException(
  45833. + "Rotation value " + rotation + " is not valid. Use only 0, 90, 180 or 270.");
  45834. + }
  45835. }
  45836. - }
  45837. }
  45838. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java
  45839. index 44eb1198884fa..8408a0e424a9b 100644
  45840. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java
  45841. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java
  45842. @@ -16,39 +16,37 @@ limitations under the License.
  45843. package com.google.android.odml.image;
  45844. import static com.google.common.truth.Truth.assertThat;
  45845. +
  45846. import static org.junit.Assert.assertThrows;
  45847. import android.graphics.Bitmap;
  45848. -import java.nio.ByteBuffer;
  45849. +
  45850. import org.junit.Test;
  45851. import org.junit.runner.RunWith;
  45852. import org.robolectric.RobolectricTestRunner;
  45853. +import java.nio.ByteBuffer;
  45854. +
  45855. /** Unit test for {@link BitmapExtractor}. */
  45856. @RunWith(RobolectricTestRunner.class)
  45857. public class BitmapExtractorTest {
  45858. + @Test
  45859. + public void extract_fromBitmap_succeeds() {
  45860. + Bitmap bitmap = TestImageCreator.createRgbaBitmap();
  45861. + MlImage image = new BitmapMlImageBuilder(bitmap).build();
  45862. +
  45863. + Bitmap result = BitmapExtractor.extract(image);
  45864. +
  45865. + assertThat(result).isSameInstanceAs(bitmap);
  45866. + }
  45867. +
  45868. + @Test
  45869. + public void extract_fromByteBuffer_throwsException() {
  45870. + ByteBuffer buffer = TestImageCreator.createRgbBuffer();
  45871. + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
  45872. + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB)
  45873. + .build();
  45874. - @Test
  45875. - public void extract_fromBitmap_succeeds() {
  45876. - Bitmap bitmap = TestImageCreator.createRgbaBitmap();
  45877. - MlImage image = new BitmapMlImageBuilder(bitmap).build();
  45878. -
  45879. - Bitmap result = BitmapExtractor.extract(image);
  45880. -
  45881. - assertThat(result).isSameInstanceAs(bitmap);
  45882. - }
  45883. -
  45884. - @Test
  45885. - public void extract_fromByteBuffer_throwsException() {
  45886. - ByteBuffer buffer = TestImageCreator.createRgbBuffer();
  45887. - MlImage image =
  45888. - new ByteBufferMlImageBuilder(
  45889. - buffer,
  45890. - TestImageCreator.getWidth(),
  45891. - TestImageCreator.getHeight(),
  45892. - MlImage.IMAGE_FORMAT_RGB)
  45893. - .build();
  45894. -
  45895. - assertThrows(IllegalArgumentException.class, () -> BitmapExtractor.extract(image));
  45896. - }
  45897. + assertThrows(IllegalArgumentException.class, () -> BitmapExtractor.extract(image));
  45898. + }
  45899. }
  45900. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java
  45901. index f9908210f2970..9a4051cdf8f6a 100644
  45902. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java
  45903. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java
  45904. @@ -16,11 +16,13 @@ limitations under the License.
  45905. package com.google.android.odml.image;
  45906. import static com.google.common.truth.Truth.assertThat;
  45907. +
  45908. import static org.junit.Assert.assertThrows;
  45909. import android.graphics.Bitmap;
  45910. import android.graphics.Bitmap.Config;
  45911. import android.graphics.Rect;
  45912. +
  45913. import org.junit.Test;
  45914. import org.junit.runner.RunWith;
  45915. import org.robolectric.RobolectricTestRunner;
  45916. @@ -28,63 +30,59 @@ import org.robolectric.RobolectricTestRunner;
  45917. /** Tests for {@link BitmapMlImageBuilder} */
  45918. @RunWith(RobolectricTestRunner.class)
  45919. public final class BitmapMlImageBuilderTest {
  45920. -
  45921. - @Test
  45922. - public void build_fromBitmap_succeeds() {
  45923. - Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
  45924. -
  45925. - MlImage image = new BitmapMlImageBuilder(bitmap).build();
  45926. - ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP);
  45927. -
  45928. - assertThat(image.getWidth()).isEqualTo(20);
  45929. - assertThat(image.getHeight()).isEqualTo(25);
  45930. - assertThat(image.getContainedImageProperties())
  45931. - .containsExactly(
  45932. - ImageProperties.builder()
  45933. - .setImageFormat(MlImage.IMAGE_FORMAT_RGBA)
  45934. - .setStorageType(MlImage.STORAGE_TYPE_BITMAP)
  45935. - .build());
  45936. - assertThat(((BitmapImageContainer) container).getBitmap().getConfig())
  45937. - .isEqualTo(Config.ARGB_8888);
  45938. - }
  45939. -
  45940. - @Test
  45941. - public void build_withOptionalProperties_succeeds() {
  45942. - Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
  45943. -
  45944. - MlImage image =
  45945. - new BitmapMlImageBuilder(bitmap)
  45946. - .setRoi(new Rect(0, 5, 10, 15))
  45947. - .setRotation(90)
  45948. - .setTimestamp(12345)
  45949. - .build();
  45950. -
  45951. - assertThat(image.getTimestamp()).isEqualTo(12345);
  45952. - assertThat(image.getRotation()).isEqualTo(90);
  45953. - assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
  45954. - }
  45955. -
  45956. - @Test
  45957. - public void build_withInvalidRotation_throwsException() {
  45958. - Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
  45959. - BitmapMlImageBuilder builder = new BitmapMlImageBuilder(bitmap);
  45960. -
  45961. - assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
  45962. - }
  45963. -
  45964. - @Test
  45965. - public void release_recyclesBitmap() {
  45966. - Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
  45967. -
  45968. - MlImage image =
  45969. - new BitmapMlImageBuilder(bitmap)
  45970. - .setRoi(new Rect(0, 5, 10, 15))
  45971. - .setRotation(90)
  45972. - .setTimestamp(12345)
  45973. - .build();
  45974. - assertThat(bitmap.isRecycled()).isFalse();
  45975. - image.close();
  45976. -
  45977. - assertThat(bitmap.isRecycled()).isTrue();
  45978. - }
  45979. + @Test
  45980. + public void build_fromBitmap_succeeds() {
  45981. + Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
  45982. +
  45983. + MlImage image = new BitmapMlImageBuilder(bitmap).build();
  45984. + ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP);
  45985. +
  45986. + assertThat(image.getWidth()).isEqualTo(20);
  45987. + assertThat(image.getHeight()).isEqualTo(25);
  45988. + assertThat(image.getContainedImageProperties())
  45989. + .containsExactly(ImageProperties.builder()
  45990. + .setImageFormat(MlImage.IMAGE_FORMAT_RGBA)
  45991. + .setStorageType(MlImage.STORAGE_TYPE_BITMAP)
  45992. + .build());
  45993. + assertThat(((BitmapImageContainer) container).getBitmap().getConfig())
  45994. + .isEqualTo(Config.ARGB_8888);
  45995. + }
  45996. +
  45997. + @Test
  45998. + public void build_withOptionalProperties_succeeds() {
  45999. + Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
  46000. +
  46001. + MlImage image = new BitmapMlImageBuilder(bitmap)
  46002. + .setRoi(new Rect(0, 5, 10, 15))
  46003. + .setRotation(90)
  46004. + .setTimestamp(12345)
  46005. + .build();
  46006. +
  46007. + assertThat(image.getTimestamp()).isEqualTo(12345);
  46008. + assertThat(image.getRotation()).isEqualTo(90);
  46009. + assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
  46010. + }
  46011. +
  46012. + @Test
  46013. + public void build_withInvalidRotation_throwsException() {
  46014. + Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
  46015. + BitmapMlImageBuilder builder = new BitmapMlImageBuilder(bitmap);
  46016. +
  46017. + assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
  46018. + }
  46019. +
  46020. + @Test
  46021. + public void release_recyclesBitmap() {
  46022. + Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
  46023. +
  46024. + MlImage image = new BitmapMlImageBuilder(bitmap)
  46025. + .setRoi(new Rect(0, 5, 10, 15))
  46026. + .setRotation(90)
  46027. + .setTimestamp(12345)
  46028. + .build();
  46029. + assertThat(bitmap.isRecycled()).isFalse();
  46030. + image.close();
  46031. +
  46032. + assertThat(bitmap.isRecycled()).isTrue();
  46033. + }
  46034. }
  46035. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java
  46036. index 2ff49010443a5..e675ba9abd479 100644
  46037. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java
  46038. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java
  46039. @@ -16,15 +16,18 @@ limitations under the License.
  46040. package com.google.android.odml.image;
  46041. import static com.google.common.truth.Truth.assertThat;
  46042. +
  46043. import static org.junit.Assert.assertThrows;
  46044. import android.graphics.Bitmap;
  46045. -import java.nio.Buffer;
  46046. -import java.nio.ByteBuffer;
  46047. +
  46048. import org.junit.Test;
  46049. import org.junit.runner.RunWith;
  46050. import org.robolectric.RobolectricTestRunner;
  46051. +import java.nio.Buffer;
  46052. +import java.nio.ByteBuffer;
  46053. +
  46054. /**
  46055. * Tests for {@link ByteBufferExtractor}.
  46056. *
  46057. @@ -35,145 +38,120 @@ import org.robolectric.RobolectricTestRunner;
  46058. */
  46059. @RunWith(RobolectricTestRunner.class)
  46060. public final class ByteBufferExtractorTest {
  46061. -
  46062. - @Test
  46063. - public void extract_fromByteBuffer_succeeds() {
  46064. - ByteBuffer byteBuffer = TestImageCreator.createRgbBuffer();
  46065. - MlImage image =
  46066. - new ByteBufferMlImageBuilder(
  46067. - byteBuffer,
  46068. - TestImageCreator.getWidth(),
  46069. - TestImageCreator.getHeight(),
  46070. - MlImage.IMAGE_FORMAT_RGB)
  46071. - .build();
  46072. -
  46073. - ByteBuffer result = ByteBufferExtractor.extract(image);
  46074. -
  46075. - assertThat(result).isEquivalentAccordingToCompareTo(byteBuffer);
  46076. - assertThat(result.isReadOnly()).isTrue();
  46077. - }
  46078. -
  46079. - @Test
  46080. - public void extract_fromBitmap_throws() {
  46081. - Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap();
  46082. - MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build();
  46083. -
  46084. - assertThrows(IllegalArgumentException.class, () -> ByteBufferExtractor.extract(image));
  46085. - }
  46086. -
  46087. - @Test
  46088. - public void extract_rgbFromRgbByteBuffer_succeeds() {
  46089. - ByteBuffer buffer = TestImageCreator.createRgbBuffer();
  46090. - MlImage image =
  46091. - new ByteBufferMlImageBuilder(
  46092. - buffer,
  46093. - TestImageCreator.getWidth(),
  46094. - TestImageCreator.getHeight(),
  46095. - MlImage.IMAGE_FORMAT_RGB)
  46096. - .build();
  46097. -
  46098. - ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
  46099. -
  46100. - assertThat(result.isReadOnly()).isTrue();
  46101. - assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
  46102. - }
  46103. -
  46104. - @Test
  46105. - public void extract_rgbFromRgbaByteBuffer_succeeds() {
  46106. - ByteBuffer buffer = TestImageCreator.createRgbaBuffer();
  46107. - MlImage image =
  46108. - new ByteBufferMlImageBuilder(
  46109. - buffer,
  46110. - TestImageCreator.getWidth(),
  46111. - TestImageCreator.getHeight(),
  46112. - MlImage.IMAGE_FORMAT_RGBA)
  46113. - .build();
  46114. -
  46115. - ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
  46116. -
  46117. - assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
  46118. - assertThat(buffer.position()).isEqualTo(0);
  46119. - }
  46120. -
  46121. - @Test
  46122. - public void extract_rgbaFromRgbByteBuffer_succeeds() {
  46123. - ByteBuffer buffer = TestImageCreator.createRgbBuffer();
  46124. - MlImage image =
  46125. - new ByteBufferMlImageBuilder(
  46126. - buffer,
  46127. - TestImageCreator.getWidth(),
  46128. - TestImageCreator.getHeight(),
  46129. - MlImage.IMAGE_FORMAT_RGB)
  46130. - .build();
  46131. -
  46132. - ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGBA);
  46133. -
  46134. - assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createOpaqueRgbaBuffer());
  46135. - assertThat(buffer.position()).isEqualTo(0);
  46136. - }
  46137. -
  46138. - @Test
  46139. - public void extract_rgbFromRgbaBitmap_succeeds() {
  46140. - Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap();
  46141. - MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build();
  46142. -
  46143. - ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
  46144. -
  46145. - assertThat(result.isReadOnly()).isTrue();
  46146. - assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
  46147. -
  46148. - // Verifies ByteBuffer is cached inside MlImage.
  46149. - ByteBufferImageContainer byteBufferImageContainer =
  46150. - (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
  46151. - assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result);
  46152. - assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
  46153. -
  46154. - // Verifies that extracted ByteBuffer is the cached one.
  46155. - ByteBuffer result2 = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
  46156. - assertThat(result2).isEqualTo(result);
  46157. - }
  46158. -
  46159. - @Test
  46160. - public void extract_unsupportedFormatFromByteBuffer_throws() {
  46161. - ByteBuffer buffer = TestImageCreator.createRgbaBuffer();
  46162. - MlImage image =
  46163. - new ByteBufferMlImageBuilder(
  46164. - buffer,
  46165. - TestImageCreator.getWidth(),
  46166. - TestImageCreator.getHeight(),
  46167. - MlImage.IMAGE_FORMAT_RGBA)
  46168. - .build();
  46169. -
  46170. - assertThrows(
  46171. - IllegalArgumentException.class,
  46172. - () -> ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_YUV_420_888));
  46173. - }
  46174. -
  46175. - @Test
  46176. - public void extractInRecommendedFormat_anyFormatFromRgbByteBuffer_succeeds() {
  46177. - ByteBuffer buffer = TestImageCreator.createRgbBuffer();
  46178. - MlImage image =
  46179. - new ByteBufferMlImageBuilder(
  46180. - buffer,
  46181. - TestImageCreator.getWidth(),
  46182. - TestImageCreator.getHeight(),
  46183. - MlImage.IMAGE_FORMAT_RGB)
  46184. - .build();
  46185. -
  46186. - ByteBufferExtractor.Result result = ByteBufferExtractor.extractInRecommendedFormat(image);
  46187. -
  46188. - assertThat(result.buffer().isReadOnly()).isTrue();
  46189. - assertThat(result.format()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
  46190. -
  46191. - // Verifies ByteBuffer is cached inside MlImage.
  46192. - ByteBufferImageContainer byteBufferImageContainer =
  46193. - (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
  46194. - assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result.buffer());
  46195. - assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
  46196. -
  46197. - // Verifies that extracted ByteBuffer is the cached one.
  46198. - ByteBufferExtractor.Result result2 = ByteBufferExtractor.extractInRecommendedFormat(image);
  46199. - assertThat(result2.buffer()).isEqualTo(result.buffer());
  46200. - assertThat(result2.format()).isEqualTo(result.format());
  46201. - }
  46202. + @Test
  46203. + public void extract_fromByteBuffer_succeeds() {
  46204. + ByteBuffer byteBuffer = TestImageCreator.createRgbBuffer();
  46205. + MlImage image = new ByteBufferMlImageBuilder(byteBuffer, TestImageCreator.getWidth(),
  46206. + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB)
  46207. + .build();
  46208. +
  46209. + ByteBuffer result = ByteBufferExtractor.extract(image);
  46210. +
  46211. + assertThat(result).isEquivalentAccordingToCompareTo(byteBuffer);
  46212. + assertThat(result.isReadOnly()).isTrue();
  46213. + }
  46214. +
  46215. + @Test
  46216. + public void extract_fromBitmap_throws() {
  46217. + Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap();
  46218. + MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build();
  46219. +
  46220. + assertThrows(IllegalArgumentException.class, () -> ByteBufferExtractor.extract(image));
  46221. + }
  46222. +
  46223. + @Test
  46224. + public void extract_rgbFromRgbByteBuffer_succeeds() {
  46225. + ByteBuffer buffer = TestImageCreator.createRgbBuffer();
  46226. + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
  46227. + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB)
  46228. + .build();
  46229. +
  46230. + ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
  46231. +
  46232. + assertThat(result.isReadOnly()).isTrue();
  46233. + assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
  46234. + }
  46235. +
  46236. + @Test
  46237. + public void extract_rgbFromRgbaByteBuffer_succeeds() {
  46238. + ByteBuffer buffer = TestImageCreator.createRgbaBuffer();
  46239. + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
  46240. + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGBA)
  46241. + .build();
  46242. +
  46243. + ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
  46244. +
  46245. + assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
  46246. + assertThat(buffer.position()).isEqualTo(0);
  46247. + }
  46248. +
  46249. + @Test
  46250. + public void extract_rgbaFromRgbByteBuffer_succeeds() {
  46251. + ByteBuffer buffer = TestImageCreator.createRgbBuffer();
  46252. + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
  46253. + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB)
  46254. + .build();
  46255. +
  46256. + ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGBA);
  46257. +
  46258. + assertThat(result).isEquivalentAccordingToCompareTo(
  46259. + TestImageCreator.createOpaqueRgbaBuffer());
  46260. + assertThat(buffer.position()).isEqualTo(0);
  46261. + }
  46262. +
  46263. + @Test
  46264. + public void extract_rgbFromRgbaBitmap_succeeds() {
  46265. + Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap();
  46266. + MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build();
  46267. +
  46268. + ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
  46269. +
  46270. + assertThat(result.isReadOnly()).isTrue();
  46271. + assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
  46272. +
  46273. + // Verifies ByteBuffer is cached inside MlImage.
  46274. + ByteBufferImageContainer byteBufferImageContainer =
  46275. + (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
  46276. + assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result);
  46277. + assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
  46278. +
  46279. + // Verifies that extracted ByteBuffer is the cached one.
  46280. + ByteBuffer result2 = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
  46281. + assertThat(result2).isEqualTo(result);
  46282. + }
  46283. +
  46284. + @Test
  46285. + public void extract_unsupportedFormatFromByteBuffer_throws() {
  46286. + ByteBuffer buffer = TestImageCreator.createRgbaBuffer();
  46287. + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
  46288. + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGBA)
  46289. + .build();
  46290. +
  46291. + assertThrows(IllegalArgumentException.class,
  46292. + () -> ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_YUV_420_888));
  46293. + }
  46294. +
  46295. + @Test
  46296. + public void extractInRecommendedFormat_anyFormatFromRgbByteBuffer_succeeds() {
  46297. + ByteBuffer buffer = TestImageCreator.createRgbBuffer();
  46298. + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
  46299. + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB)
  46300. + .build();
  46301. +
  46302. + ByteBufferExtractor.Result result = ByteBufferExtractor.extractInRecommendedFormat(image);
  46303. +
  46304. + assertThat(result.buffer().isReadOnly()).isTrue();
  46305. + assertThat(result.format()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
  46306. +
  46307. + // Verifies ByteBuffer is cached inside MlImage.
  46308. + ByteBufferImageContainer byteBufferImageContainer =
  46309. + (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
  46310. + assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result.buffer());
  46311. + assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
  46312. +
  46313. + // Verifies that extracted ByteBuffer is the cached one.
  46314. + ByteBufferExtractor.Result result2 = ByteBufferExtractor.extractInRecommendedFormat(image);
  46315. + assertThat(result2.buffer()).isEqualTo(result.buffer());
  46316. + assertThat(result2.format()).isEqualTo(result.format());
  46317. + }
  46318. }
  46319. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java
  46320. index 45ba77934a61f..374c82b3f4e8d 100644
  46321. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java
  46322. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java
  46323. @@ -16,61 +16,62 @@ limitations under the License.
  46324. package com.google.android.odml.image;
  46325. import static com.google.common.truth.Truth.assertThat;
  46326. +
  46327. import static org.junit.Assert.assertThrows;
  46328. import android.graphics.Rect;
  46329. -import java.nio.ByteBuffer;
  46330. +
  46331. import org.junit.Test;
  46332. import org.junit.runner.RunWith;
  46333. import org.robolectric.RobolectricTestRunner;
  46334. +import java.nio.ByteBuffer;
  46335. +
  46336. /** Tests for {@link ByteBufferMlImageBuilder} */
  46337. @RunWith(RobolectricTestRunner.class)
  46338. public final class ByteBufferMlImageBuilderTest {
  46339. + @Test
  46340. + public void build_fromByteBuffer_succeeds() {
  46341. + ByteBuffer buffer = ByteBuffer.allocate(500);
  46342. +
  46343. + MlImage image =
  46344. + new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB).build();
  46345. + ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
  46346. +
  46347. + assertThat(image.getWidth()).isEqualTo(20);
  46348. + assertThat(image.getHeight()).isEqualTo(25);
  46349. + assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, 20, 25));
  46350. + assertThat(image.getRotation()).isEqualTo(0);
  46351. + assertThat(image.getContainedImageProperties())
  46352. + .containsExactly(ImageProperties.builder()
  46353. + .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
  46354. + .setImageFormat(MlImage.IMAGE_FORMAT_RGB)
  46355. + .build());
  46356. + assertThat(((ByteBufferImageContainer) container).getImageFormat())
  46357. + .isEqualTo(MlImage.IMAGE_FORMAT_RGB);
  46358. + }
  46359. +
  46360. + @Test
  46361. + public void build_withOptionalProperties_succeeds() {
  46362. + ByteBuffer buffer = ByteBuffer.allocate(500);
  46363. +
  46364. + MlImage image = new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB)
  46365. + .setRoi(new Rect(0, 5, 10, 15))
  46366. + .setRotation(90)
  46367. + .setTimestamp(12345)
  46368. + .build();
  46369. +
  46370. + assertThat(image.getTimestamp()).isEqualTo(12345);
  46371. + assertThat(image.getRotation()).isEqualTo(90);
  46372. + assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
  46373. + }
  46374. +
  46375. + @Test
  46376. + public void build_withInvalidRotation_throwsException() {
  46377. + ByteBuffer buffer = ByteBuffer.allocate(500);
  46378. + ByteBufferMlImageBuilder builder =
  46379. + new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB);
  46380. - @Test
  46381. - public void build_fromByteBuffer_succeeds() {
  46382. - ByteBuffer buffer = ByteBuffer.allocate(500);
  46383. -
  46384. - MlImage image = new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB).build();
  46385. - ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
  46386. -
  46387. - assertThat(image.getWidth()).isEqualTo(20);
  46388. - assertThat(image.getHeight()).isEqualTo(25);
  46389. - assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, 20, 25));
  46390. - assertThat(image.getRotation()).isEqualTo(0);
  46391. - assertThat(image.getContainedImageProperties())
  46392. - .containsExactly(
  46393. - ImageProperties.builder()
  46394. - .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
  46395. - .setImageFormat(MlImage.IMAGE_FORMAT_RGB)
  46396. - .build());
  46397. - assertThat(((ByteBufferImageContainer) container).getImageFormat())
  46398. - .isEqualTo(MlImage.IMAGE_FORMAT_RGB);
  46399. - }
  46400. -
  46401. - @Test
  46402. - public void build_withOptionalProperties_succeeds() {
  46403. - ByteBuffer buffer = ByteBuffer.allocate(500);
  46404. -
  46405. - MlImage image =
  46406. - new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB)
  46407. - .setRoi(new Rect(0, 5, 10, 15))
  46408. - .setRotation(90)
  46409. - .setTimestamp(12345)
  46410. - .build();
  46411. -
  46412. - assertThat(image.getTimestamp()).isEqualTo(12345);
  46413. - assertThat(image.getRotation()).isEqualTo(90);
  46414. - assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
  46415. - }
  46416. -
  46417. - @Test
  46418. - public void build_withInvalidRotation_throwsException() {
  46419. - ByteBuffer buffer = ByteBuffer.allocate(500);
  46420. - ByteBufferMlImageBuilder builder =
  46421. - new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB);
  46422. -
  46423. - assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
  46424. - }
  46425. + assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
  46426. + }
  46427. }
  46428. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java
  46429. index 67ed4a7f6e2c4..fa832671e4458 100644
  46430. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java
  46431. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java
  46432. @@ -16,6 +16,7 @@ limitations under the License.
  46433. package com.google.android.odml.image;
  46434. import static com.google.common.truth.Truth.assertThat;
  46435. +
  46436. import static org.junit.Assert.assertThrows;
  46437. import static org.mockito.Mockito.when;
  46438. @@ -23,6 +24,7 @@ import android.graphics.Bitmap;
  46439. import android.graphics.Bitmap.Config;
  46440. import android.graphics.ImageFormat;
  46441. import android.media.Image;
  46442. +
  46443. import org.junit.Before;
  46444. import org.junit.Test;
  46445. import org.junit.runner.RunWith;
  46446. @@ -33,34 +35,34 @@ import org.robolectric.RobolectricTestRunner;
  46447. /** Tests for {@link MediaImageExtractor} */
  46448. @RunWith(RobolectricTestRunner.class)
  46449. public final class MediaImageExtractorTest {
  46450. - private static final int HEIGHT = 100;
  46451. - private static final int WIDTH = 50;
  46452. + private static final int HEIGHT = 100;
  46453. + private static final int WIDTH = 50;
  46454. - @Mock private Image mediaImage;
  46455. + @Mock
  46456. + private Image mediaImage;
  46457. - @Before
  46458. - public void setUp() {
  46459. - MockitoAnnotations.initMocks(this);
  46460. + @Before
  46461. + public void setUp() {
  46462. + MockitoAnnotations.initMocks(this);
  46463. - when(mediaImage.getHeight()).thenReturn(HEIGHT);
  46464. - when(mediaImage.getWidth()).thenReturn(WIDTH);
  46465. - when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888);
  46466. - }
  46467. + when(mediaImage.getHeight()).thenReturn(HEIGHT);
  46468. + when(mediaImage.getWidth()).thenReturn(WIDTH);
  46469. + when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888);
  46470. + }
  46471. - @Test
  46472. - public void extract_fromMediaMlImage_succeeds() {
  46473. - MlImage image = new MediaMlImageBuilder(mediaImage).build();
  46474. - Image extractedMediaImage = MediaImageExtractor.extract(image);
  46475. + @Test
  46476. + public void extract_fromMediaMlImage_succeeds() {
  46477. + MlImage image = new MediaMlImageBuilder(mediaImage).build();
  46478. + Image extractedMediaImage = MediaImageExtractor.extract(image);
  46479. - assertThat(extractedMediaImage).isSameInstanceAs(image);
  46480. - }
  46481. + assertThat(extractedMediaImage).isSameInstanceAs(image);
  46482. + }
  46483. - @Test
  46484. - public void extract_fromBitmapMlImage_throwsException() {
  46485. - MlImage image =
  46486. - new BitmapMlImageBuilder(
  46487. + @Test
  46488. + public void extract_fromBitmapMlImage_throwsException() {
  46489. + MlImage image = new BitmapMlImageBuilder(
  46490. Bitmap.createBitmap(/* width= */ 20, /* height= */ 25, Config.ARGB_8888))
  46491. - .build();
  46492. - assertThrows(IllegalArgumentException.class, () -> MediaImageExtractor.extract(image));
  46493. - }
  46494. + .build();
  46495. + assertThrows(IllegalArgumentException.class, () -> MediaImageExtractor.extract(image));
  46496. + }
  46497. }
  46498. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java
  46499. index 4f589874bfaf8..60397feceb067 100644
  46500. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java
  46501. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java
  46502. @@ -16,12 +16,14 @@ limitations under the License.
  46503. package com.google.android.odml.image;
  46504. import static com.google.common.truth.Truth.assertThat;
  46505. +
  46506. import static org.junit.Assert.assertThrows;
  46507. import static org.mockito.Mockito.when;
  46508. import android.graphics.ImageFormat;
  46509. import android.graphics.Rect;
  46510. import android.media.Image;
  46511. +
  46512. import org.junit.Before;
  46513. import org.junit.Test;
  46514. import org.junit.runner.RunWith;
  46515. @@ -32,58 +34,57 @@ import org.robolectric.RobolectricTestRunner;
  46516. /** Tests for {@link MediaMlImageBuilder} */
  46517. @RunWith(RobolectricTestRunner.class)
  46518. public final class MediaMlImageBuilderTest {
  46519. - private static final int HEIGHT = 100;
  46520. - private static final int WIDTH = 50;
  46521. -
  46522. - @Mock private Image mediaImage;
  46523. -
  46524. - @Before
  46525. - public void setUp() {
  46526. - MockitoAnnotations.initMocks(this);
  46527. -
  46528. - when(mediaImage.getHeight()).thenReturn(HEIGHT);
  46529. - when(mediaImage.getWidth()).thenReturn(WIDTH);
  46530. - when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888);
  46531. - }
  46532. -
  46533. - @Test
  46534. - public void build_fromMediaImage_succeeds() {
  46535. - MlImage image = new MediaMlImageBuilder(mediaImage).build();
  46536. - ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE);
  46537. -
  46538. - assertThat(image.getWidth()).isEqualTo(WIDTH);
  46539. - assertThat(image.getHeight()).isEqualTo(HEIGHT);
  46540. - assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, WIDTH, HEIGHT));
  46541. - assertThat(image.getRotation()).isEqualTo(0);
  46542. - assertThat(image.getTimestamp()).isAtLeast(0);
  46543. - assertThat(image.getContainedImageProperties())
  46544. - .containsExactly(
  46545. - ImageProperties.builder()
  46546. - .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE)
  46547. - .setImageFormat(MlImage.IMAGE_FORMAT_YUV_420_888)
  46548. - .build());
  46549. - assertThat(((MediaImageContainer) container).getImage().getFormat())
  46550. - .isEqualTo(ImageFormat.YUV_420_888);
  46551. - }
  46552. -
  46553. - @Test
  46554. - public void build_withOptionalProperties_succeeds() {
  46555. - MlImage image =
  46556. - new MediaMlImageBuilder(mediaImage)
  46557. - .setTimestamp(12345)
  46558. - .setRoi(new Rect(0, 5, 10, 15))
  46559. - .setRotation(90)
  46560. - .build();
  46561. -
  46562. - assertThat(image.getTimestamp()).isEqualTo(12345);
  46563. - assertThat(image.getRotation()).isEqualTo(90);
  46564. - assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
  46565. - }
  46566. -
  46567. - @Test
  46568. - public void build_withInvalidRotation_throwsException() {
  46569. - MediaMlImageBuilder builder = new MediaMlImageBuilder(mediaImage);
  46570. -
  46571. - assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
  46572. - }
  46573. + private static final int HEIGHT = 100;
  46574. + private static final int WIDTH = 50;
  46575. +
  46576. + @Mock
  46577. + private Image mediaImage;
  46578. +
  46579. + @Before
  46580. + public void setUp() {
  46581. + MockitoAnnotations.initMocks(this);
  46582. +
  46583. + when(mediaImage.getHeight()).thenReturn(HEIGHT);
  46584. + when(mediaImage.getWidth()).thenReturn(WIDTH);
  46585. + when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888);
  46586. + }
  46587. +
  46588. + @Test
  46589. + public void build_fromMediaImage_succeeds() {
  46590. + MlImage image = new MediaMlImageBuilder(mediaImage).build();
  46591. + ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE);
  46592. +
  46593. + assertThat(image.getWidth()).isEqualTo(WIDTH);
  46594. + assertThat(image.getHeight()).isEqualTo(HEIGHT);
  46595. + assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, WIDTH, HEIGHT));
  46596. + assertThat(image.getRotation()).isEqualTo(0);
  46597. + assertThat(image.getTimestamp()).isAtLeast(0);
  46598. + assertThat(image.getContainedImageProperties())
  46599. + .containsExactly(ImageProperties.builder()
  46600. + .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE)
  46601. + .setImageFormat(MlImage.IMAGE_FORMAT_YUV_420_888)
  46602. + .build());
  46603. + assertThat(((MediaImageContainer) container).getImage().getFormat())
  46604. + .isEqualTo(ImageFormat.YUV_420_888);
  46605. + }
  46606. +
  46607. + @Test
  46608. + public void build_withOptionalProperties_succeeds() {
  46609. + MlImage image = new MediaMlImageBuilder(mediaImage)
  46610. + .setTimestamp(12345)
  46611. + .setRoi(new Rect(0, 5, 10, 15))
  46612. + .setRotation(90)
  46613. + .build();
  46614. +
  46615. + assertThat(image.getTimestamp()).isEqualTo(12345);
  46616. + assertThat(image.getRotation()).isEqualTo(90);
  46617. + assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
  46618. + }
  46619. +
  46620. + @Test
  46621. + public void build_withInvalidRotation_throwsException() {
  46622. + MediaMlImageBuilder builder = new MediaMlImageBuilder(mediaImage);
  46623. +
  46624. + assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
  46625. + }
  46626. }
  46627. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java
  46628. index c9e7134bedd93..28f54be2c70a3 100644
  46629. --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java
  46630. +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java
  46631. @@ -17,6 +17,7 @@ package com.google.android.odml.image;
  46632. import android.graphics.Bitmap;
  46633. import android.graphics.Color;
  46634. +
  46635. import java.nio.ByteBuffer;
  46636. /**
  46637. @@ -35,113 +36,113 @@ import java.nio.ByteBuffer;
  46638. * <p>The created {@link Bitmap} is not pre-multiplied.
  46639. */
  46640. final class TestImageCreator {
  46641. + private static final int RED = 0x73;
  46642. + private static final int GREEN = 0x85;
  46643. + private static final int BLUE = 0x96;
  46644. + private static final int ALPHA = 0x70;
  46645. +
  46646. + static int getWidth() {
  46647. + return 10;
  46648. + }
  46649. +
  46650. + static int getHeight() {
  46651. + return 2;
  46652. + }
  46653. +
  46654. + /**
  46655. + * Creates an example non-pre-multiplied bitmap which is 100% opaque.
  46656. + *
  46657. + * @see TestImageCreator for details.
  46658. + */
  46659. + static Bitmap createOpaqueRgbaBitmap() {
  46660. + return createRgbaBitmap(0xff);
  46661. + }
  46662. +
  46663. + /**
  46664. + * Creates an example non-pre-multiplied bitmap which has non-trivial alpha channel.
  46665. + *
  46666. + * @see TestImageCreator for details.
  46667. + */
  46668. + static Bitmap createRgbaBitmap() {
  46669. + return createRgbaBitmap(ALPHA);
  46670. + }
  46671. - private static final int RED = 0x73;
  46672. - private static final int GREEN = 0x85;
  46673. - private static final int BLUE = 0x96;
  46674. - private static final int ALPHA = 0x70;
  46675. -
  46676. - static int getWidth() {
  46677. - return 10;
  46678. - }
  46679. -
  46680. - static int getHeight() {
  46681. - return 2;
  46682. - }
  46683. -
  46684. - /**
  46685. - * Creates an example non-pre-multiplied bitmap which is 100% opaque.
  46686. - *
  46687. - * @see TestImageCreator for details.
  46688. - */
  46689. - static Bitmap createOpaqueRgbaBitmap() {
  46690. - return createRgbaBitmap(0xff);
  46691. - }
  46692. -
  46693. - /**
  46694. - * Creates an example non-pre-multiplied bitmap which has non-trivial alpha channel.
  46695. - *
  46696. - * @see TestImageCreator for details.
  46697. - */
  46698. - static Bitmap createRgbaBitmap() {
  46699. - return createRgbaBitmap(ALPHA);
  46700. - }
  46701. -
  46702. - /**
  46703. - * Creates an example 10x2 bitmap demonstrated in the class doc. A channel sets to {@code alpha}.
  46704. - */
  46705. - static Bitmap createRgbaBitmap(int alpha) {
  46706. - int[] colors = new int[20];
  46707. - for (int i = 0; i < 5; i++) {
  46708. - colors[i] = Color.argb(alpha, 0, 0, BLUE);
  46709. - colors[i + 5] = Color.argb(alpha, 0xff, 0xff, 0xff);
  46710. - colors[i + 10] = Color.argb(alpha, 0, GREEN, 0);
  46711. - colors[i + 15] = Color.argb(alpha, RED, 0, 0);
  46712. + /**
  46713. + * Creates an example 10x2 bitmap demonstrated in the class doc. A channel sets to {@code
  46714. + * alpha}.
  46715. + */
  46716. + static Bitmap createRgbaBitmap(int alpha) {
  46717. + int[] colors = new int[20];
  46718. + for (int i = 0; i < 5; i++) {
  46719. + colors[i] = Color.argb(alpha, 0, 0, BLUE);
  46720. + colors[i + 5] = Color.argb(alpha, 0xff, 0xff, 0xff);
  46721. + colors[i + 10] = Color.argb(alpha, 0, GREEN, 0);
  46722. + colors[i + 15] = Color.argb(alpha, RED, 0, 0);
  46723. + }
  46724. + // We don't use Bitmap#createBitmap(int[] ...) here, because that method creates
  46725. + // pre-multiplied bitmaps.
  46726. + Bitmap bitmap = Bitmap.createBitmap(10, 2, Bitmap.Config.ARGB_8888);
  46727. + bitmap.setPremultiplied(false);
  46728. + bitmap.setPixels(colors, 0, 10, 0, 0, 10, 2);
  46729. + return bitmap;
  46730. }
  46731. - // We don't use Bitmap#createBitmap(int[] ...) here, because that method creates pre-multiplied
  46732. - // bitmaps.
  46733. - Bitmap bitmap = Bitmap.createBitmap(10, 2, Bitmap.Config.ARGB_8888);
  46734. - bitmap.setPremultiplied(false);
  46735. - bitmap.setPixels(colors, 0, 10, 0, 0, 10, 2);
  46736. - return bitmap;
  46737. - }
  46738. -
  46739. - /**
  46740. - * Creates an example 10*10*3 bytebuffer in R-G-B format.
  46741. - *
  46742. - * @see TestImageCreator for details.
  46743. - */
  46744. - static ByteBuffer createRgbBuffer() {
  46745. - return createRgbOrRgbaBuffer(false, 0xff);
  46746. - }
  46747. -
  46748. - /**
  46749. - * Creates an example 10*10*4 bytebuffer in R-G-B-A format.
  46750. - *
  46751. - * @see TestImageCreator for details.
  46752. - */
  46753. - static ByteBuffer createRgbaBuffer() {
  46754. - return createRgbOrRgbaBuffer(true, ALPHA);
  46755. - }
  46756. -
  46757. - /**
  46758. - * Creates an example 10*10*4 bytebuffer in R-G-B-A format, but the A channel is 0xFF.
  46759. - *
  46760. - * @see TestImageCreator for details.
  46761. - */
  46762. - static ByteBuffer createOpaqueRgbaBuffer() {
  46763. - return createRgbOrRgbaBuffer(true, 0xff);
  46764. - }
  46765. -
  46766. - /**
  46767. - * Creates an example 10x2x4 (or 10x2x3 if no alpha) bytebuffer demonstrated in the class doc.
  46768. - *
  46769. - * @param withAlpha if true, set A to {@code alpha}, otherwise A channel is ignored.
  46770. - * @param alpha alpha channel value. Only effective when {@code withAlpha} is {@code true}.
  46771. - */
  46772. - static ByteBuffer createRgbOrRgbaBuffer(boolean withAlpha, int alpha) {
  46773. - int capacity = withAlpha ? 80 : 60;
  46774. - ByteBuffer buffer = ByteBuffer.allocateDirect(capacity);
  46775. - putColorInByteBuffer(buffer, 0, 0, BLUE, withAlpha, alpha, 5);
  46776. - putColorInByteBuffer(buffer, 0xff, 0xff, 0xff, withAlpha, alpha, 5);
  46777. - putColorInByteBuffer(buffer, 0, GREEN, 0, withAlpha, alpha, 5);
  46778. - putColorInByteBuffer(buffer, RED, 0, 0, withAlpha, alpha, 5);
  46779. - buffer.rewind();
  46780. - return buffer;
  46781. - }
  46782. -
  46783. - private static void putColorInByteBuffer(
  46784. - ByteBuffer buffer, int r, int g, int b, boolean withAlpha, int alpha, int num) {
  46785. - for (int i = 0; i < num; i++) {
  46786. - buffer.put((byte) r);
  46787. - buffer.put((byte) g);
  46788. - buffer.put((byte) b);
  46789. - if (withAlpha) {
  46790. - buffer.put((byte) alpha);
  46791. - }
  46792. +
  46793. + /**
  46794. + * Creates an example 10*10*3 bytebuffer in R-G-B format.
  46795. + *
  46796. + * @see TestImageCreator for details.
  46797. + */
  46798. + static ByteBuffer createRgbBuffer() {
  46799. + return createRgbOrRgbaBuffer(false, 0xff);
  46800. + }
  46801. +
  46802. + /**
  46803. + * Creates an example 10*10*4 bytebuffer in R-G-B-A format.
  46804. + *
  46805. + * @see TestImageCreator for details.
  46806. + */
  46807. + static ByteBuffer createRgbaBuffer() {
  46808. + return createRgbOrRgbaBuffer(true, ALPHA);
  46809. + }
  46810. +
  46811. + /**
  46812. + * Creates an example 10*10*4 bytebuffer in R-G-B-A format, but the A channel is 0xFF.
  46813. + *
  46814. + * @see TestImageCreator for details.
  46815. + */
  46816. + static ByteBuffer createOpaqueRgbaBuffer() {
  46817. + return createRgbOrRgbaBuffer(true, 0xff);
  46818. + }
  46819. +
  46820. + /**
  46821. + * Creates an example 10x2x4 (or 10x2x3 if no alpha) bytebuffer demonstrated in the class doc.
  46822. + *
  46823. + * @param withAlpha if true, set A to {@code alpha}, otherwise A channel is ignored.
  46824. + * @param alpha alpha channel value. Only effective when {@code withAlpha} is {@code true}.
  46825. + */
  46826. + static ByteBuffer createRgbOrRgbaBuffer(boolean withAlpha, int alpha) {
  46827. + int capacity = withAlpha ? 80 : 60;
  46828. + ByteBuffer buffer = ByteBuffer.allocateDirect(capacity);
  46829. + putColorInByteBuffer(buffer, 0, 0, BLUE, withAlpha, alpha, 5);
  46830. + putColorInByteBuffer(buffer, 0xff, 0xff, 0xff, withAlpha, alpha, 5);
  46831. + putColorInByteBuffer(buffer, 0, GREEN, 0, withAlpha, alpha, 5);
  46832. + putColorInByteBuffer(buffer, RED, 0, 0, withAlpha, alpha, 5);
  46833. + buffer.rewind();
  46834. + return buffer;
  46835. + }
  46836. +
  46837. + private static void putColorInByteBuffer(
  46838. + ByteBuffer buffer, int r, int g, int b, boolean withAlpha, int alpha, int num) {
  46839. + for (int i = 0; i < num; i++) {
  46840. + buffer.put((byte) r);
  46841. + buffer.put((byte) g);
  46842. + buffer.put((byte) b);
  46843. + if (withAlpha) {
  46844. + buffer.put((byte) alpha);
  46845. + }
  46846. + }
  46847. }
  46848. - }
  46849. - // Should not be instantiated.
  46850. - private TestImageCreator() {}
  46851. + // Should not be instantiated.
  46852. + private TestImageCreator() {}
  46853. }
  46854. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/pybinds/_pywrap_audio_buffer.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/pybinds/_pywrap_audio_buffer.cc
  46855. index b46a997c4e254..c5e317d8a82c0 100644
  46856. --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/pybinds/_pywrap_audio_buffer.cc
  46857. +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/pybinds/_pywrap_audio_buffer.cc
  46858. @@ -39,16 +39,15 @@ PYBIND11_MODULE(_pywrap_audio_buffer, m) {
  46859. .def_readonly("sample_rate", &AudioBuffer::AudioFormat::sample_rate);
  46860. py::class_<AudioBuffer>(m, "AudioBuffer", py::buffer_protocol())
  46861. - .def(py::init([](
  46862. - py::buffer buffer, const int sample_count,
  46863. - const AudioBuffer::AudioFormat& audio_format)
  46864. - -> std::unique_ptr<AudioBuffer> {
  46865. - py::buffer_info info = buffer.request();
  46866. + .def(py::init([](py::buffer buffer, const int sample_count,
  46867. + const AudioBuffer::AudioFormat& audio_format)
  46868. + -> std::unique_ptr<AudioBuffer> {
  46869. + py::buffer_info info = buffer.request();
  46870. - auto audio_buffer = AudioBuffer::Create(
  46871. - static_cast<float*>(info.ptr), sample_count, audio_format);
  46872. - return core::get_value(audio_buffer);
  46873. - }))
  46874. + auto audio_buffer = AudioBuffer::Create(static_cast<float*>(info.ptr),
  46875. + sample_count, audio_format);
  46876. + return core::get_value(audio_buffer);
  46877. + }))
  46878. .def_property_readonly("audio_format", &AudioBuffer::GetAudioFormat)
  46879. .def_property_readonly("buffer_size", &AudioBuffer::GetBufferSize)
  46880. .def_property_readonly("float_buffer", [](AudioBuffer& self) {
  46881. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc
  46882. index 5d94db2a01b37..e2054cf645c08 100644
  46883. --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc
  46884. +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc
  46885. @@ -20,7 +20,6 @@ limitations under the License.
  46886. #include "tensorflow_lite_support/cc/task/audio/proto/classifications_proto_inc.h"
  46887. #include "tensorflow_lite_support/cc/task/processor/proto/classification_options.pb.h"
  46888. #include "tensorflow_lite_support/cc/task/processor/proto/classifications.pb.h"
  46889. -#include "tensorflow_lite_support/cc/task/processor/proto/classifications.pb.h"
  46890. #include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h"
  46891. namespace tflite {
  46892. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc
  46893. index 50e0b4f7ce4a8..8b1d67d9f8e05 100644
  46894. --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc
  46895. +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc
  46896. @@ -15,9 +15,9 @@ limitations under the License.
  46897. #include "pybind11/pybind11.h"
  46898. #include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf
  46899. -#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h"
  46900. #include "tensorflow_lite_support/cc/task/audio/audio_embedder.h"
  46901. #include "tensorflow_lite_support/cc/task/audio/core/audio_buffer.h"
  46902. +#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h"
  46903. #include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h"
  46904. namespace tflite {
  46905. @@ -50,17 +50,17 @@ PYBIND11_MODULE(_pywrap_audio_embedder, m) {
  46906. return core::get_value(embedder);
  46907. })
  46908. .def_static("cosine_similarity",
  46909. - [](const processor::FeatureVector& u,
  46910. - const processor::FeatureVector& v) -> double {
  46911. - auto similarity = AudioEmbedder::CosineSimilarity(u, v);
  46912. - return core::get_value(similarity);
  46913. - })
  46914. + [](const processor::FeatureVector& u,
  46915. + const processor::FeatureVector& v) -> double {
  46916. + auto similarity = AudioEmbedder::CosineSimilarity(u, v);
  46917. + return core::get_value(similarity);
  46918. + })
  46919. .def("embed",
  46920. - [](AudioEmbedder& self,
  46921. - const AudioBuffer& audio_buffer) -> processor::EmbeddingResult {
  46922. - auto embedding_result = self.Embed(audio_buffer);
  46923. - return core::get_value(embedding_result);
  46924. - })
  46925. + [](AudioEmbedder& self,
  46926. + const AudioBuffer& audio_buffer) -> processor::EmbeddingResult {
  46927. + auto embedding_result = self.Embed(audio_buffer);
  46928. + return core::get_value(embedding_result);
  46929. + })
  46930. .def("get_embedding_dimension", &AudioEmbedder::GetEmbeddingDimension)
  46931. .def("get_number_of_output_layers",
  46932. &AudioEmbedder::GetNumberOfOutputLayers)
  46933. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc
  46934. index 977b4e16175ac..124f5cb1ad15d 100644
  46935. --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc
  46936. +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc
  46937. @@ -43,13 +43,13 @@ PYBIND11_MODULE(image_utils, m) {
  46938. int width = info.shape[1];
  46939. int channels = info.ndim == 3 ? info.shape[2] : 1;
  46940. - return ImageData{static_cast<uint8 *>(info.ptr), width, height,
  46941. + return ImageData{static_cast<uint8*>(info.ptr), width, height,
  46942. channels};
  46943. }))
  46944. .def_readonly("width", &ImageData::width)
  46945. .def_readonly("height", &ImageData::height)
  46946. .def_readonly("channels", &ImageData::channels)
  46947. - .def_buffer([](ImageData &data) -> py::buffer_info {
  46948. + .def_buffer([](ImageData& data) -> py::buffer_info {
  46949. return py::buffer_info(
  46950. data.pixel_data, sizeof(uint8),
  46951. py::format_descriptor<uint8>::format(), 3,
  46952. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc
  46953. index 4ca20a363345e..b4f23baa6e0b1 100644
  46954. --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc
  46955. +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc
  46956. @@ -67,17 +67,17 @@ PYBIND11_MODULE(_pywrap_image_classifier, m) {
  46957. return core::get_value(classifier);
  46958. })
  46959. .def("classify",
  46960. - [](ImageClassifier& self, const ImageData& image_data)
  46961. - -> processor::ClassificationResult {
  46962. + [](ImageClassifier& self,
  46963. + const ImageData& image_data) -> processor::ClassificationResult {
  46964. auto frame_buffer = CreateFrameBufferFromImageData(image_data);
  46965. - auto vision_classification_result = self.Classify(
  46966. - *core::get_value(frame_buffer));
  46967. + auto vision_classification_result =
  46968. + self.Classify(*core::get_value(frame_buffer));
  46969. // Convert from vision::ClassificationResult to
  46970. // processor::ClassificationResult as required by the Python layer.
  46971. processor::ClassificationResult classification_result;
  46972. - classification_result.ParseFromString(
  46973. + classification_result.ParseFromString(
  46974. core::get_value(vision_classification_result)
  46975. - .SerializeAsString());
  46976. + .SerializeAsString());
  46977. return classification_result;
  46978. })
  46979. .def("classify",
  46980. @@ -96,9 +96,9 @@ PYBIND11_MODULE(_pywrap_image_classifier, m) {
  46981. // Convert from vision::ClassificationResult to
  46982. // processor::ClassificationResult as required by the Python layer.
  46983. processor::ClassificationResult classification_result;
  46984. - classification_result.ParseFromString(
  46985. + classification_result.ParseFromString(
  46986. core::get_value(vision_classification_result)
  46987. - .SerializeAsString());
  46988. + .SerializeAsString());
  46989. return classification_result;
  46990. });
  46991. }
  46992. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc
  46993. index 3ebf09fb4f284..e71048e9ebb0b 100644
  46994. --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc
  46995. +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc
  46996. @@ -47,23 +47,23 @@ PYBIND11_MODULE(_pywrap_image_segmenter, m) {
  46997. if (segmentation_options.has_display_names_locale()) {
  46998. options.set_display_names_locale(
  46999. - segmentation_options.display_names_locale());
  47000. + segmentation_options.display_names_locale());
  47001. }
  47002. if (segmentation_options.has_output_type()) {
  47003. options.set_output_type(
  47004. static_cast<ImageSegmenterOptions::OutputType>(
  47005. - segmentation_options.output_type()));
  47006. + segmentation_options.output_type()));
  47007. }
  47008. auto segmenter = ImageSegmenter::CreateFromOptions(options);
  47009. return core::get_value(segmenter);
  47010. })
  47011. .def("segment",
  47012. - [](ImageSegmenter& self, const ImageData& image_data)
  47013. - -> SegmentationResult {
  47014. + [](ImageSegmenter& self,
  47015. + const ImageData& image_data) -> SegmentationResult {
  47016. auto frame_buffer = CreateFrameBufferFromImageData(image_data);
  47017. - auto vision_segmentation_result = self.Segment(
  47018. - *core::get_value(frame_buffer));
  47019. + auto vision_segmentation_result =
  47020. + self.Segment(*core::get_value(frame_buffer));
  47021. return core::get_value(vision_segmentation_result);
  47022. });
  47023. }
  47024. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc
  47025. index 39e39c9df00e1..3749efc811019 100644
  47026. --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc
  47027. +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc
  47028. @@ -65,17 +65,16 @@ PYBIND11_MODULE(_pywrap_object_detector, m) {
  47029. return core::get_value(detector);
  47030. })
  47031. .def("detect",
  47032. - [](ObjectDetector& self, const ImageData& image_data)
  47033. - -> processor::DetectionResult {
  47034. + [](ObjectDetector& self,
  47035. + const ImageData& image_data) -> processor::DetectionResult {
  47036. auto frame_buffer = CreateFrameBufferFromImageData(image_data);
  47037. - auto vision_detection_result = self.Detect(
  47038. - *core::get_value(frame_buffer));
  47039. + auto vision_detection_result =
  47040. + self.Detect(*core::get_value(frame_buffer));
  47041. // Convert from vision::DetectionResult to
  47042. // processor::DetectionResult as required by the Python layer.
  47043. processor::DetectionResult detection_result;
  47044. - detection_result.ParseFromString(
  47045. - core::get_value(vision_detection_result)
  47046. - .SerializeAsString());
  47047. + detection_result.ParseFromString(
  47048. + core::get_value(vision_detection_result).SerializeAsString());
  47049. return detection_result;
  47050. });
  47051. }
  47052. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h
  47053. index 89c96e7d5e50a..67e0e303d4231 100644
  47054. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h
  47055. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h
  47056. @@ -29,7 +29,9 @@ namespace scann_ondevice {
  47057. namespace core {
  47058. template <typename LutType>
  47059. -void RearrangeLUT(const LutType* input_data, int batch_elems, int batch_size,
  47060. +void RearrangeLUT(const LutType* input_data,
  47061. + int batch_elems,
  47062. + int batch_size,
  47063. LutType* const output_data) {
  47064. std::vector<int64_t> simd_sizes;
  47065. if (std::is_same<LutType, float>::value) {
  47066. @@ -88,10 +90,15 @@ struct MaxQuantizationValue<uint16_t> {
  47067. };
  47068. template <typename SimdType, typename LutType, size_t NumCenters = 0>
  47069. -size_t IndexTableSumSimdBatch(const uint8_t* indices, size_t num_chunks,
  47070. - size_t num_outputs, const LutType* lookup_table,
  47071. - size_t batch_size, size_t num_centers, float min,
  47072. - float max, size_t batch_index,
  47073. +size_t IndexTableSumSimdBatch(const uint8_t* indices,
  47074. + size_t num_chunks,
  47075. + size_t num_outputs,
  47076. + const LutType* lookup_table,
  47077. + size_t batch_size,
  47078. + size_t num_centers,
  47079. + float min,
  47080. + float max,
  47081. + size_t batch_index,
  47082. float* const output) {
  47083. if (num_centers == 256) {
  47084. return IndexTableSumSimdBatch<SimdType, LutType, 256>(
  47085. @@ -176,9 +183,14 @@ size_t IndexTableSumSimdBatch(const uint8_t* indices, size_t num_chunks,
  47086. }
  47087. template <typename LutType>
  47088. -void IndexTableSum(const uint8_t* indices, size_t num_chunks,
  47089. - size_t num_outputs, const LutType* lookup_table,
  47090. - size_t batch_size, size_t num_centers, float min, float max,
  47091. +void IndexTableSum(const uint8_t* indices,
  47092. + size_t num_chunks,
  47093. + size_t num_outputs,
  47094. + const LutType* lookup_table,
  47095. + size_t batch_size,
  47096. + size_t num_centers,
  47097. + float min,
  47098. + float max,
  47099. float* const output) {
  47100. static_assert(std::is_same<LutType, uint8_t>::value ||
  47101. std::is_same<LutType, uint16_t>::value,
  47102. @@ -206,10 +218,15 @@ void IndexTableSum(const uint8_t* indices, size_t num_chunks,
  47103. }
  47104. template <>
  47105. -inline void IndexTableSum<float>(const uint8_t* indices, size_t num_chunks,
  47106. - size_t num_outputs, const float* lookup_table,
  47107. - size_t batch_size, size_t num_centers,
  47108. - float min, float max, float* const output) {
  47109. +inline void IndexTableSum<float>(const uint8_t* indices,
  47110. + size_t num_chunks,
  47111. + size_t num_outputs,
  47112. + const float* lookup_table,
  47113. + size_t batch_size,
  47114. + size_t num_centers,
  47115. + float min,
  47116. + float max,
  47117. + float* const output) {
  47118. std::fill(output, output + batch_size * num_outputs, 0.0f);
  47119. size_t i = 0;
  47120. #ifdef __AVX__
  47121. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer.cc
  47122. index ed17d7f1708f8..6df064553d2c5 100644
  47123. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer.cc
  47124. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer.cc
  47125. @@ -31,7 +31,8 @@ namespace core {
  47126. namespace {
  47127. float ComputeSquaredL2Distance(Span<const float> a, Span<const float> b) {
  47128. - if (a.size() != b.size()) return 0;
  47129. + if (a.size() != b.size())
  47130. + return 0;
  47131. float result = 0;
  47132. for (int i = 0; i < a.size(); ++i) {
  47133. result += (a[i] - b[i]) * (a[i] - b[i]);
  47134. @@ -40,7 +41,8 @@ float ComputeSquaredL2Distance(Span<const float> a, Span<const float> b) {
  47135. }
  47136. float ComputeDotProductDistance(Span<const float> a, Span<const float> b) {
  47137. - if (a.size() != b.size()) return 0;
  47138. + if (a.size() != b.size())
  47139. + return 0;
  47140. float result = 0;
  47141. for (int i = 0; i < a.size(); ++i) {
  47142. result += a[i] * b[i];
  47143. @@ -62,7 +64,8 @@ AsymmetricHashingIndexer::AsymmetricHashingIndexer(
  47144. int subspace_index = 0;
  47145. for (const AsymmetricHashingProto::SubspaceCodebook& codebook :
  47146. ah_proto.subspace()) {
  47147. - if (codebook.entry().empty()) return;
  47148. + if (codebook.entry().empty())
  47149. + return;
  47150. const int dimension = codebook.entry(0).dimension_size();
  47151. const int num_codes = codebook.entry_size();
  47152. @@ -81,13 +84,17 @@ AsymmetricHashingIndexer::AsymmetricHashingIndexer(
  47153. }
  47154. total_dimension_ = 0;
  47155. - for (const uint8_t dim : dimensions_) total_dimension_ += dim;
  47156. + for (const uint8_t dim : dimensions_)
  47157. + total_dimension_ += dim;
  47158. }
  47159. void AsymmetricHashingIndexer::EncodeDatapoint(
  47160. - absl::Span<const float> original, absl::Span<uint8_t> encoded) const {
  47161. - if (original.size() != total_dimension_) return;
  47162. - if (encoded.size() != dimensions_.size()) return;
  47163. + absl::Span<const float> original,
  47164. + absl::Span<uint8_t> encoded) const {
  47165. + if (original.size() != total_dimension_)
  47166. + return;
  47167. + if (encoded.size() != dimensions_.size())
  47168. + return;
  47169. int start_index = 0;
  47170. for (int i = 0; i < dimensions_.size(); ++i) {
  47171. @@ -118,7 +125,8 @@ void AsymmetricHashingIndexer::EncodeDatapoint(
  47172. }
  47173. absl::Status AsymmetricHashingIndexer::DecodeDatapoint(
  47174. - absl::Span<const uint8_t> encoded, absl::Span<float> reconstructed) const {
  47175. + absl::Span<const uint8_t> encoded,
  47176. + absl::Span<float> reconstructed) const {
  47177. if (encoded.size() < dimensions_.size()) {
  47178. return absl::InvalidArgumentError("Mismatching dimensions");
  47179. }
  47180. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer.h
  47181. index 0328a75837ba9..a0515667e8373 100644
  47182. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer.h
  47183. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer.h
  47184. @@ -21,7 +21,7 @@ limitations under the License.
  47185. #include <string>
  47186. #include "absl/status/status.h" // from @com_google_absl
  47187. -#include "absl/types/span.h" // from @com_google_absl
  47188. +#include "absl/types/span.h" // from @com_google_absl
  47189. #include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
  47190. namespace tflite {
  47191. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer_test.cc
  47192. index 3ef7427d6e21a..fca3d8f3d21c9 100644
  47193. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer_test.cc
  47194. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer_test.cc
  47195. @@ -92,7 +92,8 @@ TEST(IndexerTest, SquaredL2AsymmetricReconstruct1) {
  47196. indexer.EncodeDatapoint(datapoint, absl::MakeSpan(result));
  47197. vector<float> datapoint_recon(5, 0);
  47198. - SUPPORT_EXPECT_OK(indexer.DecodeDatapoint(result, absl::MakeSpan(datapoint_recon)));
  47199. + SUPPORT_EXPECT_OK(
  47200. + indexer.DecodeDatapoint(result, absl::MakeSpan(datapoint_recon)));
  47201. EXPECT_EQ(std::vector<float>({0.1, 0.2, -0.1, -0.2, -0.3}), datapoint_recon);
  47202. }
  47203. @@ -122,7 +123,8 @@ TEST(IndexerTest, SquaredL2AsymmetricReconstruct2) {
  47204. indexer.EncodeDatapoint(datapoint, absl::MakeSpan(result));
  47205. vector<float> datapoint_recon = {0.1, 0.2, -0.1, -0.2, -0.3};
  47206. - SUPPORT_EXPECT_OK(indexer.DecodeDatapoint(result, absl::MakeSpan(datapoint_recon)));
  47207. + SUPPORT_EXPECT_OK(
  47208. + indexer.DecodeDatapoint(result, absl::MakeSpan(datapoint_recon)));
  47209. EXPECT_EQ(std::vector<float>({0.9, 0.8, -0.3, -0.2, -0.1}), datapoint_recon);
  47210. }
  47211. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.cc
  47212. index 3217c57c0e831..e86fd77cc3321 100644
  47213. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.cc
  47214. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.cc
  47215. @@ -87,7 +87,9 @@ bool Partitioner::Partition(const Eigen::Ref<const Eigen::MatrixXf>& queries,
  47216. return true;
  47217. }
  47218. -int Partitioner::NumPartitions() const { return leaves_.rows(); }
  47219. +int Partitioner::NumPartitions() const {
  47220. + return leaves_.rows();
  47221. +}
  47222. bool NoOpPartitioner::Partition(
  47223. const Eigen::Ref<const Eigen::MatrixXf>& queries,
  47224. @@ -108,7 +110,9 @@ bool NoOpPartitioner::Partition(
  47225. return true;
  47226. }
  47227. -int NoOpPartitioner::NumPartitions() const { return 1; }
  47228. +int NoOpPartitioner::NumPartitions() const {
  47229. + return 1;
  47230. +}
  47231. } // namespace core
  47232. } // namespace scann_ondevice
  47233. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h
  47234. index 2a1fb36e9f28e..f4e9eb9e34804 100644
  47235. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h
  47236. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h
  47237. @@ -17,8 +17,8 @@ limitations under the License.
  47238. #include <utility>
  47239. +#include "Eigen/Core" // from @eigen
  47240. #include "absl/types/optional.h" // from @com_google_absl
  47241. -#include "Eigen/Core" // from @eigen
  47242. #include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
  47243. namespace tflite {
  47244. @@ -45,7 +45,8 @@ class Partitioner : public PartitionerInterface {
  47245. }
  47246. private:
  47247. - Partitioner(Eigen::MatrixXf leaves, Eigen::VectorXf leaf_norms,
  47248. + Partitioner(Eigen::MatrixXf leaves,
  47249. + Eigen::VectorXf leaf_norms,
  47250. DistanceMeasure distance)
  47251. : leaves_(std::move(leaves)),
  47252. leaf_norms_(std::move(leaf_norms)),
  47253. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher.h
  47254. index 9fab870790db6..419681b829b1d 100644
  47255. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher.h
  47256. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher.h
  47257. @@ -22,8 +22,8 @@ limitations under the License.
  47258. #include <vector>
  47259. #include <glog/logging.h>
  47260. +#include "Eigen/Core" // from @eigen
  47261. #include "absl/types/span.h" // from @com_google_absl
  47262. -#include "Eigen/Core" // from @eigen
  47263. #include "tensorflow_lite_support/cc/port/integral_types.h"
  47264. #include "tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h"
  47265. #include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h"
  47266. @@ -47,7 +47,8 @@ void ComputeAHDistance(const QueryInfo& query_info,
  47267. template <class T>
  47268. bool AsymmetricHashFindNeighbors(const QueryInfo& query_info,
  47269. Eigen::Ref<const Matrix8u> database,
  47270. - size_t global_offset, absl::Span<T> topn) {
  47271. + size_t global_offset,
  47272. + absl::Span<T> topn) {
  47273. const int batch_size = query_info.query_lut->cols();
  47274. if (topn.size() != batch_size) {
  47275. return false;
  47276. @@ -67,7 +68,8 @@ template <class T>
  47277. bool AsymmetricHashFindNeighbors(Eigen::Ref<const Eigen::MatrixXf> queries,
  47278. const PreProcessorInterface& preprocessor,
  47279. Eigen::Ref<const Matrix8u> database,
  47280. - size_t global_offset, absl::Span<T> topn) {
  47281. + size_t global_offset,
  47282. + absl::Span<T> topn) {
  47283. if (queries.cols() != topn.size()) {
  47284. return false;
  47285. }
  47286. @@ -116,10 +118,12 @@ template <class T>
  47287. class AsymmetricHashLeafSearcherT : public SearcherInterfaceT<T> {
  47288. public:
  47289. static std::unique_ptr<AsymmetricHashLeafSearcherT<T>> Create(
  47290. - std::shared_ptr<QueryInfo::Matrix<uint8_t>> database, int global_offset,
  47291. + std::shared_ptr<QueryInfo::Matrix<uint8_t>> database,
  47292. + int global_offset,
  47293. std::shared_ptr<PreProcessorInterface> preprocessor);
  47294. static std::unique_ptr<AsymmetricHashLeafSearcherT<T>> Create(
  47295. - std::shared_ptr<QueryInfo::Matrix<uint8_t>> database, int global_offset,
  47296. + std::shared_ptr<QueryInfo::Matrix<uint8_t>> database,
  47297. + int global_offset,
  47298. std::shared_ptr<PreProcessorInterface> preprocessor,
  47299. size_t mini_batch_size);
  47300. bool FindNeighbors(const Eigen::Ref<const Eigen::MatrixXf>& queries,
  47301. @@ -128,7 +132,8 @@ class AsymmetricHashLeafSearcherT : public SearcherInterfaceT<T> {
  47302. private:
  47303. AsymmetricHashLeafSearcherT(
  47304. - std::shared_ptr<QueryInfo::Matrix<uint8_t>> database, int global_offset,
  47305. + std::shared_ptr<QueryInfo::Matrix<uint8_t>> database,
  47306. + int global_offset,
  47307. std::shared_ptr<PreProcessorInterface> preprocessor,
  47308. size_t mini_batch_size)
  47309. : database_(std::move(database)),
  47310. @@ -154,7 +159,8 @@ class LinearLeafSearcherT : public SearcherInterfaceT<T> {
  47311. private:
  47312. LinearLeafSearcherT(std::shared_ptr<Eigen::MatrixXf> database,
  47313. - DistanceMeasure distance_measure, int global_offset)
  47314. + DistanceMeasure distance_measure,
  47315. + int global_offset)
  47316. : database_(std::move(database)),
  47317. distance_measure_(distance_measure),
  47318. global_offset_(global_offset) {}
  47319. @@ -167,7 +173,8 @@ class LinearLeafSearcherT : public SearcherInterfaceT<T> {
  47320. template <class T>
  47321. std::unique_ptr<AsymmetricHashLeafSearcherT<T>>
  47322. AsymmetricHashLeafSearcherT<T>::Create(
  47323. - std::shared_ptr<Matrix8u> database, int global_offset,
  47324. + std::shared_ptr<Matrix8u> database,
  47325. + int global_offset,
  47326. std::shared_ptr<PreProcessorInterface> preprocessor) {
  47327. return AsymmetricHashLeafSearcherT<T>::Create(
  47328. database, global_offset, preprocessor,
  47329. @@ -177,7 +184,8 @@ AsymmetricHashLeafSearcherT<T>::Create(
  47330. template <class T>
  47331. std::unique_ptr<AsymmetricHashLeafSearcherT<T>>
  47332. AsymmetricHashLeafSearcherT<T>::Create(
  47333. - std::shared_ptr<Matrix8u> database, int global_offset,
  47334. + std::shared_ptr<Matrix8u> database,
  47335. + int global_offset,
  47336. std::shared_ptr<PreProcessorInterface> preprocessor,
  47337. size_t mini_batch_size) {
  47338. if (mini_batch_size == 0 || global_offset < 0) {
  47339. @@ -220,7 +228,8 @@ bool AsymmetricHashLeafSearcherT<T>::FindNeighbors(const QueryInfo& query_info,
  47340. template <class T>
  47341. std::unique_ptr<LinearLeafSearcherT<T>> LinearLeafSearcherT<T>::Create(
  47342. - std::shared_ptr<Eigen::MatrixXf> database, DistanceMeasure distance_measure,
  47343. + std::shared_ptr<Eigen::MatrixXf> database,
  47344. + DistanceMeasure distance_measure,
  47345. int global_offset) {
  47346. if (global_offset < 0) {
  47347. return nullptr;
  47348. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher_test.cc
  47349. index 8c67bca0da939..f3931f3619b8d 100644
  47350. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher_test.cc
  47351. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher_test.cc
  47352. @@ -21,8 +21,8 @@ limitations under the License.
  47353. #include <utility>
  47354. #include <glog/logging.h>
  47355. +#include "Eigen/Core" // from @eigen
  47356. #include "absl/synchronization/mutex.h" // from @com_google_absl
  47357. -#include "Eigen/Core" // from @eigen
  47358. #include "tensorflow_lite_support/cc/port/gmock.h"
  47359. #include "tensorflow_lite_support/cc/port/gtest.h"
  47360. #include "tensorflow_lite_support/cc/port/integral_types.h"
  47361. @@ -520,9 +520,10 @@ TEST_P(SearcherTest, AsymmetricHashMiniBatchedSimdFail) {
  47362. }
  47363. #endif
  47364. -INSTANTIATE_TEST_SUITE_P(SearcherTest, SearcherTest,
  47365. - Values(std::numeric_limits<size_t>::max(), 1, 2, 3, 7,
  47366. - 23));
  47367. +INSTANTIATE_TEST_SUITE_P(
  47368. + SearcherTest,
  47369. + SearcherTest,
  47370. + Values(std::numeric_limits<size_t>::max(), 1, 2, 3, 7, 23));
  47371. } // namespace
  47372. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h
  47373. index 8f53ddf0669c4..3e5a6b00736d0 100644
  47374. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h
  47375. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h
  47376. @@ -44,7 +44,8 @@ class TopNAmortizedConstant {
  47377. std::vector<T> TakeUnsorted() {
  47378. DCHECK_GT(limit_, 0) << "Cannot call TakeUnsorted on uninitialized "
  47379. "TopNAmortizedConstant instance.";
  47380. - if (elements_.size() > limit_) PartitionAndResizeToLimit();
  47381. + if (elements_.size() > limit_)
  47382. + PartitionAndResizeToLimit();
  47383. auto result = std::move(elements_);
  47384. elements_.clear();
  47385. approx_bottom_ = original_approx_bottom_;
  47386. @@ -53,13 +54,15 @@ class TopNAmortizedConstant {
  47387. const std::vector<T>& ExtractUnsorted() {
  47388. DCHECK_GT(limit_, 0) << "Cannot call ExtractUnsorted on uninitialized "
  47389. "TopNAmortizedConstant instance.";
  47390. - if (elements_.size() > limit_) PartitionAndResizeToLimit();
  47391. + if (elements_.size() > limit_)
  47392. + PartitionAndResizeToLimit();
  47393. return elements_;
  47394. }
  47395. std::vector<T> Take() {
  47396. DCHECK_GT(limit_, 0) << "Cannot call Take on uninitialized "
  47397. "TopNAmortizedConstant instance.";
  47398. - if (elements_.size() > limit_) PartitionAndResizeToLimit();
  47399. + if (elements_.size() > limit_)
  47400. + PartitionAndResizeToLimit();
  47401. std::sort(elements_.begin(), elements_.end(), cmp_);
  47402. auto result = std::move(elements_);
  47403. elements_.clear();
  47404. @@ -100,7 +103,8 @@ struct Comparator {
  47405. const std::pair<float, int>& b) const {
  47406. return a.first < b.first;
  47407. }
  47408. - bool operator()(float distance, int,
  47409. + bool operator()(float distance,
  47410. + int,
  47411. const std::pair<float, int>& other) const {
  47412. return distance < other.first;
  47413. }
  47414. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.cc
  47415. index 8e45119d7364d..e8be5f6572f17 100644
  47416. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.cc
  47417. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.cc
  47418. @@ -18,17 +18,17 @@ limitations under the License.
  47419. #include <cstddef>
  47420. #include <memory>
  47421. -#include "absl/memory/memory.h" // from @com_google_absl
  47422. -#include "absl/status/status.h" // from @com_google_absl
  47423. -#include "absl/status/statusor.h" // from @com_google_absl
  47424. -#include "absl/strings/str_format.h" // from @com_google_absl
  47425. +#include "absl/memory/memory.h" // from @com_google_absl
  47426. +#include "absl/status/status.h" // from @com_google_absl
  47427. +#include "absl/status/statusor.h" // from @com_google_absl
  47428. +#include "absl/strings/str_format.h" // from @com_google_absl
  47429. #include "absl/strings/string_view.h" // from @com_google_absl
  47430. -#include "leveldb/cache.h" // from @com_google_leveldb
  47431. -#include "leveldb/iterator.h" // from @com_google_leveldb
  47432. -#include "leveldb/options.h" // from @com_google_leveldb
  47433. -#include "leveldb/slice.h" // from @com_google_leveldb
  47434. -#include "leveldb/status.h" // from @com_google_leveldb
  47435. -#include "leveldb/table.h" // from @com_google_leveldb
  47436. +#include "leveldb/cache.h" // from @com_google_leveldb
  47437. +#include "leveldb/iterator.h" // from @com_google_leveldb
  47438. +#include "leveldb/options.h" // from @com_google_leveldb
  47439. +#include "leveldb/slice.h" // from @com_google_leveldb
  47440. +#include "leveldb/status.h" // from @com_google_leveldb
  47441. +#include "leveldb/table.h" // from @com_google_leveldb
  47442. #include "tensorflow_lite_support/cc/port/status_macros.h"
  47443. #include "tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h"
  47444. #include "tensorflow_lite_support/scann_ondevice/cc/utils.h"
  47445. @@ -60,7 +60,8 @@ absl::StatusOr<absl::string_view> GetValueForKey(leveldb::Iterator* iterator,
  47446. /* static */
  47447. absl::StatusOr<std::unique_ptr<Index>> Index::CreateFromIndexBuffer(
  47448. - const char* buffer_data, size_t buffer_size) {
  47449. + const char* buffer_data,
  47450. + size_t buffer_size) {
  47451. // Use absl::WrapUnique() to call private constructor:
  47452. // https://abseil.io/tips/126.
  47453. std::unique_ptr<Index> index = absl::WrapUnique(new Index());
  47454. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.h
  47455. index c630e6f827caa..15e709183a606 100644
  47456. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.h
  47457. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.h
  47458. @@ -18,12 +18,12 @@ limitations under the License.
  47459. #include <memory>
  47460. -#include "absl/status/status.h" // from @com_google_absl
  47461. -#include "absl/status/statusor.h" // from @com_google_absl
  47462. +#include "absl/status/status.h" // from @com_google_absl
  47463. +#include "absl/status/statusor.h" // from @com_google_absl
  47464. #include "absl/strings/string_view.h" // from @com_google_absl
  47465. -#include "leveldb/cache.h" // from @com_google_leveldb
  47466. -#include "leveldb/iterator.h" // from @com_google_leveldb
  47467. -#include "leveldb/table.h" // from @com_google_leveldb
  47468. +#include "leveldb/cache.h" // from @com_google_leveldb
  47469. +#include "leveldb/iterator.h" // from @com_google_leveldb
  47470. +#include "leveldb/table.h" // from @com_google_leveldb
  47471. #include "tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h"
  47472. #include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h"
  47473. @@ -43,7 +43,8 @@ class Index {
  47474. // Warning: Does not take ownership of the provided buffer, which must outlive
  47475. // this object.
  47476. static absl::StatusOr<std::unique_ptr<Index>> CreateFromIndexBuffer(
  47477. - const char* buffer_data, size_t buffer_size);
  47478. + const char* buffer_data,
  47479. + size_t buffer_size);
  47480. // Parses and returns the `IndexConfig` stored in the index file.
  47481. absl::StatusOr<IndexConfig> GetIndexConfig() const;
  47482. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc
  47483. index fe5d1ef1175e4..0d802024c2b01 100644
  47484. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc
  47485. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc
  47486. @@ -21,13 +21,13 @@ limitations under the License.
  47487. #include <vector>
  47488. #include "absl/container/btree_map.h" // from @com_google_absl
  47489. -#include "absl/status/status.h" // from @com_google_absl
  47490. -#include "absl/strings/str_format.h" // from @com_google_absl
  47491. -#include "leveldb/options.h" // from @com_google_leveldb
  47492. -#include "leveldb/slice.h" // from @com_google_leveldb
  47493. -#include "leveldb/status.h" // from @com_google_leveldb
  47494. -#include "leveldb/table_builder.h" // from @com_google_leveldb
  47495. -#include "leveldb/write_batch.h" // from @com_google_leveldb
  47496. +#include "absl/status/status.h" // from @com_google_absl
  47497. +#include "absl/strings/str_format.h" // from @com_google_absl
  47498. +#include "leveldb/options.h" // from @com_google_leveldb
  47499. +#include "leveldb/slice.h" // from @com_google_leveldb
  47500. +#include "leveldb/status.h" // from @com_google_leveldb
  47501. +#include "leveldb/table_builder.h" // from @com_google_leveldb
  47502. +#include "leveldb/write_batch.h" // from @com_google_leveldb
  47503. #include "tensorflow_lite_support/cc/port/status_macros.h"
  47504. #include "tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h"
  47505. #include "tensorflow_lite_support/scann_ondevice/cc/utils.h"
  47506. @@ -56,8 +56,10 @@ template <typename T>
  47507. absl::StatusOr<std::string> CreateIndexBufferImpl(
  47508. absl::Span<const T> database,
  47509. absl::optional<absl::Span<const uint32_t>> partition_assignment,
  47510. - absl::Span<const std::string> metadata, const std::string& userinfo,
  47511. - IndexConfig index_config, bool compression) {
  47512. + absl::Span<const std::string> metadata,
  47513. + const std::string& userinfo,
  47514. + IndexConfig index_config,
  47515. + bool compression) {
  47516. size_t num_partitions = 1;
  47517. if (partition_assignment) {
  47518. if (partition_assignment->size() != metadata.size()) {
  47519. @@ -145,8 +147,8 @@ absl::StatusOr<std::string> CreateIndexBufferImpl(
  47520. } // namespace
  47521. -absl::StatusOr<std::string> CreateIndexBuffer(
  47522. - const IndexedArtifacts& artifacts, bool compression) {
  47523. +absl::StatusOr<std::string> CreateIndexBuffer(const IndexedArtifacts& artifacts,
  47524. + bool compression) {
  47525. if (artifacts.hashed_database.has_value() &&
  47526. artifacts.float_database.has_value()) {
  47527. return absl::InvalidArgumentError(
  47528. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h
  47529. index e8f8f06220578..53cac9b583da4 100644
  47530. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h
  47531. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h
  47532. @@ -16,12 +16,12 @@ limitations under the License.
  47533. #ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_INDEX_FILE_MUTATOR_H_
  47534. #define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_INDEX_FILE_MUTATOR_H_
  47535. -#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
  47536. -#include "absl/status/statusor.h" // from @com_google_absl
  47537. +#include "absl/status/statusor.h" // from @com_google_absl
  47538. #include "absl/strings/string_view.h" // from @com_google_absl
  47539. -#include "absl/types/optional.h" // from @com_google_absl
  47540. -#include "absl/types/span.h" // from @com_google_absl
  47541. -#include "leveldb/db.h" // from @com_google_leveldb
  47542. +#include "absl/types/optional.h" // from @com_google_absl
  47543. +#include "absl/types/span.h" // from @com_google_absl
  47544. +#include "leveldb/db.h" // from @com_google_leveldb
  47545. +#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
  47546. #include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h"
  47547. namespace tflite {
  47548. @@ -60,8 +60,8 @@ struct IndexedArtifacts {
  47549. // Creates a byte buffer for the index file from the artifacts. Returns errors
  47550. // when there are not exactly one database specified, or other issues with input
  47551. // such as shape mismatch, invalid partition indices etc.
  47552. -absl::StatusOr<std::string> CreateIndexBuffer(
  47553. - const IndexedArtifacts& artifacts, bool compression);
  47554. +absl::StatusOr<std::string> CreateIndexBuffer(const IndexedArtifacts& artifacts,
  47555. + bool compression);
  47556. } // namespace scann_ondevice
  47557. } // namespace tflite
  47558. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.cc
  47559. index 7be71b90ef91d..59b9deb8e8682 100644
  47560. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.cc
  47561. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.cc
  47562. @@ -19,8 +19,8 @@ limitations under the License.
  47563. #include <cstddef>
  47564. #include <cstdint>
  47565. -#include "leveldb/env.h" // from @com_google_leveldb
  47566. -#include "leveldb/slice.h" // from @com_google_leveldb
  47567. +#include "leveldb/env.h" // from @com_google_leveldb
  47568. +#include "leveldb/slice.h" // from @com_google_leveldb
  47569. #include "leveldb/status.h" // from @com_google_leveldb
  47570. namespace tflite {
  47571. @@ -32,7 +32,8 @@ MemRandomAccessFile::MemRandomAccessFile(const char* buffer_data,
  47572. MemRandomAccessFile::~MemRandomAccessFile() {}
  47573. -leveldb::Status MemRandomAccessFile::Read(uint64_t offset, size_t n,
  47574. +leveldb::Status MemRandomAccessFile::Read(uint64_t offset,
  47575. + size_t n,
  47576. leveldb::Slice* result,
  47577. char* scratch) const {
  47578. // Sanity check.
  47579. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h
  47580. index 0cf9cbfed59f4..5ca68f2e2c91e 100644
  47581. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h
  47582. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h
  47583. @@ -19,8 +19,8 @@ limitations under the License.
  47584. #include <cstddef>
  47585. #include <cstdint>
  47586. -#include "leveldb/env.h" // from @com_google_leveldb
  47587. -#include "leveldb/slice.h" // from @com_google_leveldb
  47588. +#include "leveldb/env.h" // from @com_google_leveldb
  47589. +#include "leveldb/slice.h" // from @com_google_leveldb
  47590. #include "leveldb/status.h" // from @com_google_leveldb
  47591. namespace tflite {
  47592. @@ -39,7 +39,9 @@ class MemRandomAccessFile : public leveldb::RandomAccessFile {
  47593. // Override of the `Read` function. Note that `scratch` is unused in the
  47594. // implementation.
  47595. - leveldb::Status Read(uint64_t offset, size_t n, leveldb::Slice* result,
  47596. + leveldb::Status Read(uint64_t offset,
  47597. + size_t n,
  47598. + leveldb::Slice* result,
  47599. char* scratch) const override;
  47600. // Class is movable and non-copyable.
  47601. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h
  47602. index bb346bc7f12dc..842e837927d4e 100644
  47603. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h
  47604. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h
  47605. @@ -20,10 +20,10 @@ limitations under the License.
  47606. #include <string>
  47607. #include "absl/status/statusor.h" // from @com_google_absl
  47608. -#include "absl/strings/cord.h" // from @com_google_absl
  47609. -#include "leveldb/env.h" // from @com_google_leveldb
  47610. -#include "leveldb/slice.h" // from @com_google_leveldb
  47611. -#include "leveldb/status.h" // from @com_google_leveldb
  47612. +#include "absl/strings/cord.h" // from @com_google_absl
  47613. +#include "leveldb/env.h" // from @com_google_leveldb
  47614. +#include "leveldb/slice.h" // from @com_google_leveldb
  47615. +#include "leveldb/status.h" // from @com_google_leveldb
  47616. namespace tflite {
  47617. namespace scann_ondevice {
  47618. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/python/index_builder_py_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/python/index_builder_py_wrapper.cc
  47619. index da147af88bc2a..709564035ff1f 100644
  47620. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/python/index_builder_py_wrapper.cc
  47621. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/python/index_builder_py_wrapper.cc
  47622. @@ -15,14 +15,14 @@ limitations under the License.
  47623. #include <string>
  47624. -#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
  47625. #include "absl/types/optional.h" // from @com_google_absl
  47626. -#include "absl/types/span.h" // from @com_google_absl
  47627. +#include "absl/types/span.h" // from @com_google_absl
  47628. #include "pybind11/cast.h"
  47629. #include "pybind11/pybind11.h"
  47630. #include "pybind11/pytypes.h"
  47631. -#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil
  47632. +#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil
  47633. #include "pybind11_abseil/status_casters.h" // from @pybind11_abseil
  47634. +#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
  47635. #include "tensorflow_lite_support/scann_ondevice/cc/index_builder.h"
  47636. namespace pybind11 {
  47637. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc
  47638. index 07da739f4a888..a1af840cc2f14 100644
  47639. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc
  47640. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc
  47641. @@ -18,18 +18,18 @@ limitations under the License.
  47642. #include <cstdint>
  47643. #include <string>
  47644. -#include "absl/flags/flag.h" // from @com_google_absl
  47645. -#include "absl/memory/memory.h" // from @com_google_absl
  47646. -#include "absl/status/status.h" // from @com_google_absl
  47647. -#include "absl/strings/str_format.h" // from @com_google_absl
  47648. +#include "absl/flags/flag.h" // from @com_google_absl
  47649. +#include "absl/memory/memory.h" // from @com_google_absl
  47650. +#include "absl/status/status.h" // from @com_google_absl
  47651. +#include "absl/strings/str_format.h" // from @com_google_absl
  47652. #include "absl/strings/string_view.h" // from @com_google_absl
  47653. -#include "absl/types/span.h" // from @com_google_absl
  47654. -#include "leveldb/env.h" // from @com_google_leveldb
  47655. -#include "leveldb/iterator.h" // from @com_google_leveldb
  47656. -#include "leveldb/options.h" // from @com_google_leveldb
  47657. -#include "leveldb/slice.h" // from @com_google_leveldb
  47658. -#include "leveldb/status.h" // from @com_google_leveldb
  47659. -#include "leveldb/table.h" // from @com_google_leveldb
  47660. +#include "absl/types/span.h" // from @com_google_absl
  47661. +#include "leveldb/env.h" // from @com_google_leveldb
  47662. +#include "leveldb/iterator.h" // from @com_google_leveldb
  47663. +#include "leveldb/options.h" // from @com_google_leveldb
  47664. +#include "leveldb/slice.h" // from @com_google_leveldb
  47665. +#include "leveldb/status.h" // from @com_google_leveldb
  47666. +#include "leveldb/table.h" // from @com_google_leveldb
  47667. #include "tensorflow_lite_support/cc/port/gmock.h"
  47668. #include "tensorflow_lite_support/cc/port/gtest.h"
  47669. #include "tensorflow_lite_support/cc/port/status_matchers.h"
  47670. @@ -137,22 +137,23 @@ TEST_P(PopulateIndexFileTest, WritesHashedDatabaseWithPartitioner) {
  47671. {
  47672. tflite::scann_ondevice::core::ScannOnDeviceConfig config =
  47673. - ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(R"pb(
  47674. - partitioner: {
  47675. - leaf { dimension: 0 dimension: 0 }
  47676. - leaf { dimension: 1 dimension: 1 }
  47677. - leaf { dimension: 2 dimension: 2 }
  47678. - leaf { dimension: 3 dimension: 3 }
  47679. - leaf { dimension: 4 dimension: 4 }
  47680. - leaf { dimension: 5 dimension: 5 }
  47681. - leaf { dimension: 6 dimension: 6 }
  47682. - leaf { dimension: 7 dimension: 7 }
  47683. - leaf { dimension: 8 dimension: 8 }
  47684. - leaf { dimension: 9 dimension: 9 }
  47685. - leaf { dimension: 10 dimension: 10 }
  47686. - leaf { dimension: 11 dimension: 11 }
  47687. - }
  47688. - )pb");
  47689. + ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(
  47690. + R"pb(
  47691. + partitioner: {
  47692. + leaf { dimension: 0 dimension: 0 }
  47693. + leaf { dimension: 1 dimension: 1 }
  47694. + leaf { dimension: 2 dimension: 2 }
  47695. + leaf { dimension: 3 dimension: 3 }
  47696. + leaf { dimension: 4 dimension: 4 }
  47697. + leaf { dimension: 5 dimension: 5 }
  47698. + leaf { dimension: 6 dimension: 6 }
  47699. + leaf { dimension: 7 dimension: 7 }
  47700. + leaf { dimension: 8 dimension: 8 }
  47701. + leaf { dimension: 9 dimension: 9 }
  47702. + leaf { dimension: 10 dimension: 10 }
  47703. + leaf { dimension: 11 dimension: 11 }
  47704. + }
  47705. + )pb");
  47706. std::vector<uint8_t> hashed_database;
  47707. hashed_database.reserve(kNumEmbeddings * kDimensions);
  47708. for (int i = 0; i < kNumEmbeddings; ++i) {
  47709. @@ -202,16 +203,18 @@ TEST_P(PopulateIndexFileTest, WritesHashedDatabaseWithPartitioner) {
  47710. auto hashed_table_iterator =
  47711. absl::WrapUnique(hashed_table->NewIterator(leveldb::ReadOptions()));
  47712. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string serialized_config,
  47713. - LookupKey(hashed_table_iterator.get(), "INDEX_CONFIG"));
  47714. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  47715. + std::string serialized_config,
  47716. + LookupKey(hashed_table_iterator.get(), "INDEX_CONFIG"));
  47717. IndexConfig index_config;
  47718. EXPECT_TRUE(index_config.ParseFromString(serialized_config));
  47719. EXPECT_THAT(
  47720. index_config,
  47721. EqualsProto(CreateExpectedConfigWithPartitioner(IndexConfig::UINT8)));
  47722. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string userinfo,
  47723. - LookupKey(hashed_table_iterator.get(), "USER_INFO"));
  47724. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  47725. + std::string userinfo,
  47726. + LookupKey(hashed_table_iterator.get(), "USER_INFO"));
  47727. EXPECT_EQ(userinfo, "hashed_userinfo");
  47728. // Partition assignment is based on i % kNumPartitions, so:
  47729. @@ -253,9 +256,10 @@ TEST_P(PopulateIndexFileTest, WritesHashedDatabaseWithoutPartitioner) {
  47730. {
  47731. tflite::scann_ondevice::core::ScannOnDeviceConfig config =
  47732. - ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(R"pb(
  47733. - query_distance: SQUARED_L2_DISTANCE
  47734. - )pb");
  47735. + ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(
  47736. + R"pb(
  47737. + query_distance: SQUARED_L2_DISTANCE
  47738. + )pb");
  47739. std::vector<uint8_t> hashed_database;
  47740. hashed_database.reserve(kNumEmbeddings * kDimensions);
  47741. for (int i = 0; i < kNumEmbeddings; ++i) {
  47742. @@ -299,22 +303,23 @@ TEST_P(PopulateIndexFileTest, WritesHashedDatabaseWithoutPartitioner) {
  47743. auto float_table_iterator =
  47744. absl::WrapUnique(float_table->NewIterator(leveldb::ReadOptions()));
  47745. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string serialized_config,
  47746. - LookupKey(float_table_iterator.get(), "INDEX_CONFIG"));
  47747. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  47748. + std::string serialized_config,
  47749. + LookupKey(float_table_iterator.get(), "INDEX_CONFIG"));
  47750. IndexConfig index_config;
  47751. EXPECT_TRUE(index_config.ParseFromString(serialized_config));
  47752. EXPECT_THAT(
  47753. index_config,
  47754. EqualsProto(CreateExpectedConfigWithoutPartitioner(IndexConfig::UINT8)));
  47755. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string userinfo,
  47756. - LookupKey(float_table_iterator.get(), "USER_INFO"));
  47757. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  47758. + std::string userinfo, LookupKey(float_table_iterator.get(), "USER_INFO"));
  47759. EXPECT_EQ(userinfo, "hashed_userinfo");
  47760. // Check that the unique embedding partition has the exact same contents as
  47761. // the database used at construction time.
  47762. SUPPORT_ASSERT_OK_AND_ASSIGN(std::string raw_partition_hashed,
  47763. - LookupKey(float_table_iterator.get(), "E_0"));
  47764. + LookupKey(float_table_iterator.get(), "E_0"));
  47765. std::vector<char> hashed_partition(raw_partition_hashed.begin(),
  47766. raw_partition_hashed.end());
  47767. std::vector<char> expected;
  47768. @@ -342,22 +347,23 @@ TEST_P(PopulateIndexFileTest, WritesFloatDatabaseWithPartitioner) {
  47769. {
  47770. tflite::scann_ondevice::core::ScannOnDeviceConfig config =
  47771. - ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(R"pb(
  47772. - partitioner: {
  47773. - leaf { dimension: 0 dimension: 0 }
  47774. - leaf { dimension: 1 dimension: 1 }
  47775. - leaf { dimension: 2 dimension: 2 }
  47776. - leaf { dimension: 3 dimension: 3 }
  47777. - leaf { dimension: 4 dimension: 4 }
  47778. - leaf { dimension: 5 dimension: 5 }
  47779. - leaf { dimension: 6 dimension: 6 }
  47780. - leaf { dimension: 7 dimension: 7 }
  47781. - leaf { dimension: 8 dimension: 8 }
  47782. - leaf { dimension: 9 dimension: 9 }
  47783. - leaf { dimension: 10 dimension: 10 }
  47784. - leaf { dimension: 11 dimension: 11 }
  47785. - }
  47786. - )pb");
  47787. + ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(
  47788. + R"pb(
  47789. + partitioner: {
  47790. + leaf { dimension: 0 dimension: 0 }
  47791. + leaf { dimension: 1 dimension: 1 }
  47792. + leaf { dimension: 2 dimension: 2 }
  47793. + leaf { dimension: 3 dimension: 3 }
  47794. + leaf { dimension: 4 dimension: 4 }
  47795. + leaf { dimension: 5 dimension: 5 }
  47796. + leaf { dimension: 6 dimension: 6 }
  47797. + leaf { dimension: 7 dimension: 7 }
  47798. + leaf { dimension: 8 dimension: 8 }
  47799. + leaf { dimension: 9 dimension: 9 }
  47800. + leaf { dimension: 10 dimension: 10 }
  47801. + leaf { dimension: 11 dimension: 11 }
  47802. + }
  47803. + )pb");
  47804. std::vector<float> float_database;
  47805. float_database.reserve(kNumEmbeddings * kDimensions);
  47806. for (int i = 0; i < kNumEmbeddings; ++i) {
  47807. @@ -407,16 +413,17 @@ TEST_P(PopulateIndexFileTest, WritesFloatDatabaseWithPartitioner) {
  47808. auto float_table_iterator =
  47809. absl::WrapUnique(float_table->NewIterator(leveldb::ReadOptions()));
  47810. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string serialized_config,
  47811. - LookupKey(float_table_iterator.get(), "INDEX_CONFIG"));
  47812. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  47813. + std::string serialized_config,
  47814. + LookupKey(float_table_iterator.get(), "INDEX_CONFIG"));
  47815. IndexConfig index_config;
  47816. EXPECT_TRUE(index_config.ParseFromString(serialized_config));
  47817. EXPECT_THAT(
  47818. index_config,
  47819. EqualsProto(CreateExpectedConfigWithPartitioner(IndexConfig::FLOAT)));
  47820. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string userinfo,
  47821. - LookupKey(float_table_iterator.get(), "USER_INFO"));
  47822. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  47823. + std::string userinfo, LookupKey(float_table_iterator.get(), "USER_INFO"));
  47824. EXPECT_EQ(userinfo, "float_userinfo");
  47825. // Partition assignment is based on i % kNumPartitions, so:
  47826. @@ -461,9 +468,10 @@ TEST_P(PopulateIndexFileTest, WritesFloatDatabaseWithoutPartitioner) {
  47827. {
  47828. tflite::scann_ondevice::core::ScannOnDeviceConfig config =
  47829. - ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(R"pb(
  47830. - query_distance: SQUARED_L2_DISTANCE
  47831. - )pb");
  47832. + ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(
  47833. + R"pb(
  47834. + query_distance: SQUARED_L2_DISTANCE
  47835. + )pb");
  47836. std::vector<float> float_database;
  47837. float_database.reserve(kNumEmbeddings * kDimensions);
  47838. for (int i = 0; i < kNumEmbeddings; ++i) {
  47839. @@ -506,22 +514,23 @@ TEST_P(PopulateIndexFileTest, WritesFloatDatabaseWithoutPartitioner) {
  47840. auto float_table_iterator =
  47841. absl::WrapUnique(float_table->NewIterator(leveldb::ReadOptions()));
  47842. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string serialized_config,
  47843. - LookupKey(float_table_iterator.get(), "INDEX_CONFIG"));
  47844. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  47845. + std::string serialized_config,
  47846. + LookupKey(float_table_iterator.get(), "INDEX_CONFIG"));
  47847. IndexConfig index_config;
  47848. EXPECT_TRUE(index_config.ParseFromString(serialized_config));
  47849. EXPECT_THAT(
  47850. index_config,
  47851. EqualsProto(CreateExpectedConfigWithoutPartitioner(IndexConfig::FLOAT)));
  47852. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string userinfo,
  47853. - LookupKey(float_table_iterator.get(), "USER_INFO"));
  47854. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  47855. + std::string userinfo, LookupKey(float_table_iterator.get(), "USER_INFO"));
  47856. EXPECT_EQ(userinfo, "float_userinfo");
  47857. // Check that the unique embedding partition has the exact same contents as
  47858. // the database used at construction time.
  47859. SUPPORT_ASSERT_OK_AND_ASSIGN(std::string raw_partition_float,
  47860. - LookupKey(float_table_iterator.get(), "E_0"));
  47861. + LookupKey(float_table_iterator.get(), "E_0"));
  47862. const float* raw_partition_float_ptr =
  47863. reinterpret_cast<const float*>(raw_partition_float.data());
  47864. std::vector<float> float_partition(
  47865. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_test.cc
  47866. index 983dd8d2bc8e8..cc1225f679f66 100644
  47867. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_test.cc
  47868. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_test.cc
  47869. @@ -18,9 +18,8 @@ limitations under the License.
  47870. #include <cstdint>
  47871. #include <memory>
  47872. -#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
  47873. -#include "absl/flags/flag.h" // from @com_google_absl
  47874. -#include "absl/status/status.h" // from @com_google_absl
  47875. +#include "absl/flags/flag.h" // from @com_google_absl
  47876. +#include "absl/status/status.h" // from @com_google_absl
  47877. #include "absl/strings/string_view.h" // from @com_google_absl
  47878. #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
  47879. #include "tensorflow_lite_support/cc/port/gmock.h"
  47880. @@ -29,6 +28,7 @@ limitations under the License.
  47881. #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
  47882. #include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
  47883. #include "tensorflow_lite_support/cc/test/test_utils.h"
  47884. +#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
  47885. #include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h"
  47886. namespace tflite {
  47887. @@ -47,10 +47,10 @@ constexpr char kDummyIndexPath[] =
  47888. TEST(CreateFromOptionsTest, Succeeds) {
  47889. // Load file in memory using ExternalFile.
  47890. ExternalFile file;
  47891. - file.set_file_name(
  47892. - JoinPath("./" /*test src dir*/, kDummyIndexPath));
  47893. - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ExternalFileHandler> handler,
  47894. - ExternalFileHandler::CreateFromExternalFile(&file));
  47895. + file.set_file_name(JoinPath("./" /*test src dir*/, kDummyIndexPath));
  47896. + SUPPORT_ASSERT_OK_AND_ASSIGN(
  47897. + std::unique_ptr<ExternalFileHandler> handler,
  47898. + ExternalFileHandler::CreateFromExternalFile(&file));
  47899. absl::string_view file_contents = handler->GetFileContent();
  47900. SUPPORT_EXPECT_OK(
  47901. @@ -62,8 +62,7 @@ class IndexTest : public tflite_shims::testing::Test {
  47902. IndexTest() {
  47903. // Load file in memory using ExternalFile.
  47904. ExternalFile file;
  47905. - file.set_file_name(
  47906. - JoinPath("./" /*test src dir*/, kDummyIndexPath));
  47907. + file.set_file_name(JoinPath("./" /*test src dir*/, kDummyIndexPath));
  47908. handler_ = ExternalFileHandler::CreateFromExternalFile(&file).value();
  47909. absl::string_view file_contents = handler_->GetFileContent();
  47910. // Build index.
  47911. @@ -98,18 +97,18 @@ TEST_F(IndexTest, GetUserInfoSucceeds) {
  47912. TEST_F(IndexTest, GetPartitionAtIndexSucceeds) {
  47913. SUPPORT_ASSERT_OK_AND_ASSIGN(absl::string_view partition_0,
  47914. - index_->GetPartitionAtIndex(0));
  47915. + index_->GetPartitionAtIndex(0));
  47916. EXPECT_EQ(partition_0.size(), 8);
  47917. - const uint8_t *partition =
  47918. - reinterpret_cast<const uint8_t *>(partition_0.data());
  47919. + const uint8_t* partition =
  47920. + reinterpret_cast<const uint8_t*>(partition_0.data());
  47921. for (int i = 0; i < 8; ++i) {
  47922. EXPECT_EQ(partition[i], i);
  47923. }
  47924. SUPPORT_ASSERT_OK_AND_ASSIGN(absl::string_view partition_1,
  47925. - index_->GetPartitionAtIndex(1));
  47926. + index_->GetPartitionAtIndex(1));
  47927. EXPECT_EQ(partition_1.size(), 4);
  47928. - partition = reinterpret_cast<const uint8_t *>(partition_1.data());
  47929. + partition = reinterpret_cast<const uint8_t*>(partition_1.data());
  47930. for (int i = 0; i < 4; ++i) {
  47931. EXPECT_EQ(partition[i], i + 8);
  47932. }
  47933. @@ -122,15 +121,15 @@ TEST_F(IndexTest, GetPartitionAtIndexFailsOutOfBounds) {
  47934. TEST_F(IndexTest, GetMetadataAtIndexSucceeds) {
  47935. SUPPORT_ASSERT_OK_AND_ASSIGN(absl::string_view metadata_0,
  47936. - index_->GetMetadataAtIndex(0));
  47937. + index_->GetMetadataAtIndex(0));
  47938. EXPECT_EQ(metadata_0, "metadata_0");
  47939. SUPPORT_ASSERT_OK_AND_ASSIGN(absl::string_view metadata_1,
  47940. - index_->GetMetadataAtIndex(1));
  47941. + index_->GetMetadataAtIndex(1));
  47942. EXPECT_EQ(metadata_1, "metadata_1");
  47943. SUPPORT_ASSERT_OK_AND_ASSIGN(absl::string_view metadata_2,
  47944. - index_->GetMetadataAtIndex(2));
  47945. + index_->GetMetadataAtIndex(2));
  47946. EXPECT_EQ(metadata_2, "metadata_2");
  47947. }
  47948. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/mem_writable_file_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/mem_writable_file_test.cc
  47949. index 3230b34db05ba..afb55e5472161 100644
  47950. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/mem_writable_file_test.cc
  47951. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/mem_writable_file_test.cc
  47952. @@ -26,7 +26,7 @@ namespace {
  47953. TEST(MemWritableFileTest, AppendsContent) {
  47954. std::string buffer;
  47955. SUPPORT_ASSERT_OK_AND_ASSIGN(auto mem_writable_file,
  47956. - MemWritableFile::Create(&buffer));
  47957. + MemWritableFile::Create(&buffer));
  47958. ASSERT_TRUE(mem_writable_file->Append("aaa").ok());
  47959. EXPECT_EQ(buffer, "aaa");
  47960. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc
  47961. index ca364e06e7d1d..1ae7e0ce9ed09 100644
  47962. --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc
  47963. +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc
  47964. @@ -16,17 +16,17 @@ limitations under the License.
  47965. #include <cstdint>
  47966. #include <vector>
  47967. -#include "absl/memory/memory.h" // from @com_google_absl
  47968. -#include "absl/status/status.h" // from @com_google_absl
  47969. -#include "absl/status/statusor.h" // from @com_google_absl
  47970. +#include "absl/memory/memory.h" // from @com_google_absl
  47971. +#include "absl/status/status.h" // from @com_google_absl
  47972. +#include "absl/status/statusor.h" // from @com_google_absl
  47973. #include "absl/strings/str_format.h" // from @com_google_absl
  47974. -#include "leveldb/env.h" // from @com_google_leveldb
  47975. -#include "leveldb/options.h" // from @com_google_leveldb
  47976. -#include "leveldb/table.h" // from @com_google_leveldb
  47977. +#include "leveldb/env.h" // from @com_google_leveldb
  47978. +#include "leveldb/options.h" // from @com_google_leveldb
  47979. +#include "leveldb/table.h" // from @com_google_leveldb
  47980. #include "pybind11/cast.h"
  47981. #include "pybind11/pybind11.h"
  47982. #include "pybind11/pytypes.h"
  47983. -#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil
  47984. +#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil
  47985. #include "pybind11_abseil/status_casters.h" // from @pybind11_abseil
  47986. namespace pybind11 {
  47987. diff --git a/third_party/tflite_support/src/third_party/fft2d/fft.h b/third_party/tflite_support/src/third_party/fft2d/fft.h
  47988. index 36d838b7f6280..35dbcc766c169 100644
  47989. --- a/third_party/tflite_support/src/third_party/fft2d/fft.h
  47990. +++ b/third_party/tflite_support/src/third_party/fft2d/fft.h
  47991. @@ -22,12 +22,12 @@ limitations under the License.
  47992. extern "C" {
  47993. #endif
  47994. -extern void cdft(int, int, double *, int *, double *);
  47995. -extern void rdft(int, int, double *, int *, double *);
  47996. -extern void ddct(int, int, double *, int *, double *);
  47997. -extern void ddst(int, int, double *, int *, double *);
  47998. -extern void dfct(int, double *, double *, int *, double *);
  47999. -extern void dfst(int, double *, double *, int *, double *);
  48000. +extern void cdft(int, int, double*, int*, double*);
  48001. +extern void rdft(int, int, double*, int*, double*);
  48002. +extern void ddct(int, int, double*, int*, double*);
  48003. +extern void ddst(int, int, double*, int*, double*);
  48004. +extern void dfct(int, double*, double*, int*, double*);
  48005. +extern void dfst(int, double*, double*, int*, double*);
  48006. #ifdef __cplusplus
  48007. }
  48008. diff --git a/third_party/tflite_support/src/third_party/fft2d/fft2d.h b/third_party/tflite_support/src/third_party/fft2d/fft2d.h
  48009. index d587b3b441ce2..d79441827d54c 100644
  48010. --- a/third_party/tflite_support/src/third_party/fft2d/fft2d.h
  48011. +++ b/third_party/tflite_support/src/third_party/fft2d/fft2d.h
  48012. @@ -22,12 +22,12 @@ limitations under the License.
  48013. extern "C" {
  48014. #endif
  48015. -extern void cdft2d(int, int, int, double **, double *, int *, double *);
  48016. -extern void rdft2d(int, int, int, double **, double *, int *, double *);
  48017. -extern void ddct2d(int, int, int, double **, double *, int *, double *);
  48018. -extern void ddst2d(int, int, int, double **, double *, int *, double *);
  48019. -extern void ddct8x8s(int isgn, double **a);
  48020. -extern void ddct16x16s(int isgn, double **a);
  48021. +extern void cdft2d(int, int, int, double**, double*, int*, double*);
  48022. +extern void rdft2d(int, int, int, double**, double*, int*, double*);
  48023. +extern void ddct2d(int, int, int, double**, double*, int*, double*);
  48024. +extern void ddst2d(int, int, int, double**, double*, int*, double*);
  48025. +extern void ddct8x8s(int isgn, double** a);
  48026. +extern void ddct16x16s(int isgn, double** a);
  48027. #ifdef __cplusplus
  48028. }
  48029. --
  48030. 2.36.1.124.g0e6072fb45-goog