123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762277227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538253925402541254225432544254525462547254825492550255125522553255425552556255725582559256025612562256325642565256625672568256925702571257225732574257525762577257825792580258125822583258425852586258725882589259025912592259325942595259625972598259926002601260226032604260526062607260826092610261126122613261426152616261726182619262026212622262326242625262626272628262926302631263226332634263526362637263826392640264126422643264426452646264726482649265026512652265326542655265626572658265926602661266226632664266526662667266826692670267126722673267426752676267726782679268026812682268326842685268626872688268926902691269226932694269526962697269826992700270127022703270427052706270727082709271027112712271327142715271627172718271927202721272227232724272527262727272827292730273127322733273427352736273727382739274027412742274327442745274627472748274927502751275227532754275527562757275827592760276127622763276427652766276727682769277027712772277327742775277627772778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763777377837793780378137823783378437853786378737883789379037913792379337943795379637973798379938003801380238033804380538063807380838093810381138123813381438153816381738183819382038213822382338243825382638273828382938303831383238333834383538363837383838393840384138423843384438453846384738483849385038513852385338543855385638573858385938603861386238633864386538663867386838693870387138723873387438753876387738783879388038813882388338843885388638873888388938903891389238933894389538963897389838993900390139023903390439053906390739083909391039113912391339143915391639173918391939203921392239233924392539263927392839293930393139323933393439353936393739383939394039413942394339443945394639473948394939503951395239533954395539563957395839593960396139623963396439653966396739683969397039713972397339743975397639773978397939803981398239833984398539863987398839893990399139923993399439953996399739983999400040014002400340044005400640074008400940104011401240134014401540164017401840194020402140224023402440254026402740284029403040314032403340344035403640374038403940404041404240434044404540464047404840494050405140524053405440554056405740584059406040614062406340644065406640674068406940704071407240734074407540764077407840794080408140824083408440854086408740884089409040914092409340944095409640974098409941004101410241034104410541064107410841094110411141124113411441154116411741184119412041214122412341244125412641274128412941304131413241334134413541364137413841394140414141424143414441454146414741484149415041514152415341544155415641574158415941604161416241634164416541664167416841694170417141724173417441754176417741784179418041814182418341844185418641874188418941904191419241934194419541964197419841994200420142024203420442054206420742084209421042114212421342144215421642174218421942204221422242234224422542264227422842294230423142324233423442354236423742384239424042414242424342444245424642474248424942504251425242534254425542564257425842594260426142624263426442654266426742684269427042714272427342744275427642774278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857485848594860486148624863486448654866486748684869487048714872487348744875487648774878487948804881488248834884488548864887488848894890489148924893489448954896489748984899490049014902490349044905490649074908490949104911491249134914491549164917491849194920492149224923492449254926492749284929493049314932493349344935493649374938493949404941494249434944494549464947494849494950495149524953495449554956495749584959496049614962496349644965496649674968496949704971497249734974497549764977497849794980498149824983498449854986498749884989499049914992499349944995499649974998499950005001500250035004500550065007500850095010501150125013501450155016501750185019502050215022502350245025502650275028502950305031503250335034503550365037503850395040504150425043504450455046504750485049505050515052505350545055505650575058505950605061506250635064506550665067506850695070507150725073507450755076507750785079508050815082508350845085508650875088508950905091509250935094509550965097509850995100510151025103510451055106510751085109511051115112511351145115511651175118511951205121512251235124512551265127512851295130513151325133513451355136513751385139514051415142514351445145514651475148514951505151515251535154515551565157515851595160516151625163516451655166516751685169517051715172517351745175517651775178517951805181518251835184518551865187518851895190519151925193519451955196519751985199520052015202520352045205520652075208520952105211521252135214521552165217521852195220522152225223522452255226522752285229523052315232523352345235523652375238523952405241524252435244524552465247524852495250525152525253525452555256525752585259526052615262526352645265526652675268526952705271527252735274527552765277527852795280528152825283528452855286528752885289529052915292529352945295529652975298529953005301530253035304530553065307530853095310531153125313531453155316531753185319532053215322532353245325532653275328532953305331533253335334533553365337533853395340534153425343534453455346534753485349535053515352535353545355535653575358535953605361536253635364536553665367536853695370537153725373537453755376537753785379538053815382538353845385538653875388538953905391539253935394539553965397539853995400540154025403540454055406540754085409541054115412541354145415541654175418541954205421542254235424542554265427542854295430543154325433543454355436543754385439544054415442544354445445544654475448544954505451545254535454545554565457545854595460546154625463546454655466546754685469547054715472547354745475547654775478547954805481548254835484548554865487548854895490549154925493549454955496549754985499550055015502550355045505550655075508550955105511551255135514551555165517551855195520552155225523552455255526552755285529553055315532553355345535553655375538553955405541554255435544554555465547554855495550555155525553555455555556555755585559556055615562556355645565556655675568556955705571557255735574557555765577557855795580558155825583558455855586558755885589559055915592559355945595559655975598559956005601560256035604560556065607560856095610561156125613561456155616561756185619562056215622562356245625562656275628562956305631563256335634563556365637563856395640564156425643564456455646564756485649565056515652565356545655565656575658565956605661566256635664566556665667566856695670567156725673567456755676567756785679568056815682568356845685568656875688568956905691569256935694569556965697569856995700570157025703570457055706570757085709571057115712571357145715571657175718571957205721572257235724572557265727572857295730573157325733573457355736573757385739574057415742574357445745574657475748574957505751575257535754575557565757575857595760576157625763576457655766576757685769577057715772577357745775577657775778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165616661676168616961706171617261736174617561766177617861796180618161826183618461856186618761886189619061916192619361946195619661976198619962006201620262036204620562066207620862096210621162126213621462156216621762186219622062216222622362246225622662276228622962306231623262336234623562366237623862396240624162426243624462456246624762486249625062516252625362546255625662576258625962606261626262636264626562666267626862696270627162726273627462756276627762786279628062816282628362846285628662876288628962906291629262936294629562966297629862996300630163026303630463056306630763086309631063116312631363146315631663176318631963206321632263236324632563266327632863296330633163326333633463356336633763386339634063416342634363446345634663476348634963506351635263536354635563566357635863596360636163626363636463656366636763686369637063716372637363746375637663776378637963806381638263836384638563866387638863896390639163926393639463956396639763986399640064016402640364046405640664076408640964106411641264136414641564166417641864196420642164226423642464256426642764286429643064316432643364346435643664376438643964406441644264436444644564466447644864496450645164526453645464556456645764586459646064616462646364646465646664676468646964706471647264736474647564766477647864796480648164826483648464856486648764886489649064916492649364946495649664976498649965006501650265036504650565066507650865096510651165126513651465156516651765186519652065216522652365246525652665276528652965306531653265336534653565366537653865396540654165426543654465456546654765486549655065516552655365546555655665576558655965606561656265636564656565666567656865696570657165726573657465756576657765786579658065816582658365846585658665876588658965906591659265936594659565966597659865996600660166026603660466056606660766086609661066116612661366146615661666176618661966206621662266236624662566266627662866296630663166326633663466356636663766386639664066416642664366446645664666476648664966506651665266536654665566566657665866596660666166626663666466656666666766686669667066716672667366746675667666776678667966806681668266836684668566866687668866896690669166926693669466956696669766986699670067016702670367046705670667076708670967106711671267136714671567166717671867196720672167226723672467256726672767286729673067316732673367346735673667376738673967406741674267436744674567466747674867496750675167526753675467556756675767586759676067616762676367646765676667676768676967706771677267736774677567766777677867796780678167826783678467856786678767886789679067916792679367946795679667976798679968006801680268036804680568066807680868096810681168126813681468156816681768186819682068216822682368246825682668276828682968306831683268336834683568366837683868396840684168426843684468456846684768486849685068516852685368546855685668576858685968606861686268636864686568666867686868696870687168726873687468756876687768786879688068816882688368846885688668876888688968906891689268936894689568966897689868996900690169026903690469056906690769086909691069116912691369146915691669176918691969206921692269236924692569266927692869296930693169326933693469356936693769386939694069416942694369446945694669476948694969506951695269536954695569566957695869596960696169626963696469656966696769686969697069716972697369746975697669776978697969806981698269836984698569866987698869896990699169926993699469956996699769986999700070017002700370047005700670077008700970107011701270137014701570167017701870197020702170227023702470257026702770287029703070317032703370347035703670377038703970407041704270437044704570467047704870497050705170527053705470557056705770587059706070617062706370647065706670677068706970707071707270737074707570767077707870797080708170827083708470857086708770887089709070917092709370947095709670977098709971007101710271037104710571067107710871097110711171127113711471157116711771187119712071217122712371247125712671277128712971307131713271337134713571367137713871397140714171427143714471457146714771487149715071517152715371547155715671577158715971607161716271637164716571667167716871697170717171727173717471757176717771787179718071817182718371847185718671877188718971907191719271937194719571967197719871997200720172027203720472057206720772087209721072117212721372147215721672177218721972207221722272237224722572267227722872297230723172327233723472357236723772387239724072417242724372447245724672477248724972507251725272537254725572567257725872597260726172627263726472657266726772687269727072717272727372747275727672777278727972807281728272837284728572867287728872897290729172927293729472957296729772987299730073017302730373047305730673077308730973107311731273137314731573167317731873197320732173227323732473257326732773287329733073317332733373347335733673377338733973407341734273437344734573467347734873497350735173527353735473557356735773587359736073617362736373647365736673677368736973707371737273737374737573767377737873797380738173827383738473857386738773887389739073917392739373947395739673977398739974007401740274037404740574067407740874097410741174127413741474157416741774187419742074217422742374247425742674277428742974307431743274337434743574367437743874397440744174427443744474457446744774487449745074517452745374547455745674577458745974607461746274637464746574667467746874697470747174727473747474757476747774787479748074817482748374847485748674877488748974907491749274937494749574967497749874997500750175027503750475057506750775087509751075117512751375147515751675177518751975207521752275237524752575267527752875297530753175327533753475357536753775387539754075417542754375447545754675477548754975507551755275537554755575567557755875597560756175627563756475657566756775687569757075717572757375747575757675777578757975807581758275837584758575867587758875897590759175927593759475957596759775987599760076017602760376047605760676077608760976107611761276137614761576167617761876197620762176227623762476257626762776287629763076317632763376347635763676377638763976407641764276437644764576467647764876497650765176527653765476557656765776587659766076617662766376647665766676677668766976707671767276737674767576767677767876797680768176827683768476857686768776887689769076917692769376947695769676977698769977007701770277037704770577067707770877097710771177127713771477157716771777187719772077217722772377247725772677277728772977307731773277337734773577367737773877397740774177427743774477457746774777487749775077517752775377547755775677577758775977607761776277637764776577667767776877697770777177727773777477757776777777787779778077817782778377847785778677877788778977907791779277937794779577967797779877997800780178027803780478057806780778087809781078117812781378147815781678177818781978207821782278237824782578267827782878297830783178327833783478357836783778387839784078417842784378447845784678477848784978507851785278537854785578567857785878597860786178627863786478657866786778687869787078717872787378747875787678777878787978807881788278837884788578867887788878897890789178927893789478957896789778987899790079017902790379047905790679077908790979107911791279137914791579167917791879197920792179227923792479257926792779287929793079317932793379347935793679377938793979407941794279437944794579467947794879497950795179527953795479557956795779587959796079617962796379647965796679677968796979707971797279737974797579767977797879797980798179827983798479857986798779887989799079917992799379947995799679977998799980008001800280038004800580068007800880098010801180128013801480158016801780188019802080218022802380248025802680278028802980308031803280338034803580368037803880398040804180428043804480458046804780488049805080518052805380548055805680578058805980608061806280638064806580668067806880698070807180728073807480758076807780788079808080818082808380848085808680878088808980908091809280938094809580968097809880998100810181028103810481058106810781088109811081118112811381148115811681178118811981208121812281238124812581268127812881298130813181328133813481358136813781388139814081418142814381448145814681478148814981508151815281538154815581568157815881598160816181628163816481658166816781688169817081718172817381748175817681778178817981808181818281838184818581868187818881898190819181928193819481958196819781988199820082018202820382048205820682078208820982108211821282138214821582168217821882198220822182228223822482258226822782288229823082318232823382348235823682378238823982408241824282438244824582468247824882498250825182528253825482558256825782588259826082618262826382648265826682678268826982708271827282738274827582768277827882798280828182828283828482858286828782888289829082918292829382948295829682978298829983008301830283038304830583068307830883098310831183128313831483158316831783188319832083218322832383248325832683278328832983308331833283338334833583368337833883398340834183428343834483458346834783488349835083518352835383548355835683578358835983608361836283638364836583668367836883698370837183728373837483758376837783788379838083818382838383848385838683878388838983908391839283938394839583968397839883998400840184028403840484058406840784088409841084118412841384148415841684178418841984208421842284238424842584268427842884298430843184328433843484358436843784388439844084418442844384448445844684478448844984508451845284538454845584568457845884598460846184628463846484658466846784688469847084718472847384748475847684778478847984808481848284838484848584868487848884898490849184928493849484958496849784988499850085018502850385048505850685078508850985108511851285138514851585168517851885198520852185228523852485258526852785288529853085318532853385348535853685378538853985408541854285438544854585468547854885498550855185528553855485558556855785588559856085618562856385648565856685678568856985708571857285738574857585768577857885798580858185828583858485858586858785888589859085918592859385948595859685978598859986008601860286038604860586068607860886098610861186128613861486158616861786188619862086218622862386248625862686278628862986308631863286338634863586368637863886398640864186428643864486458646864786488649865086518652865386548655865686578658865986608661866286638664866586668667866886698670867186728673867486758676867786788679868086818682868386848685868686878688868986908691869286938694869586968697869886998700870187028703870487058706870787088709871087118712871387148715871687178718871987208721872287238724872587268727872887298730873187328733873487358736873787388739874087418742874387448745874687478748874987508751875287538754875587568757875887598760876187628763876487658766876787688769877087718772877387748775877687778778877987808781878287838784878587868787878887898790879187928793879487958796879787988799880088018802880388048805880688078808880988108811881288138814881588168817881888198820882188228823882488258826882788288829883088318832883388348835883688378838883988408841884288438844884588468847884888498850885188528853885488558856885788588859886088618862886388648865886688678868886988708871887288738874887588768877887888798880888188828883888488858886888788888889889088918892889388948895889688978898889989008901890289038904890589068907890889098910891189128913891489158916891789188919892089218922892389248925892689278928892989308931893289338934893589368937893889398940894189428943894489458946894789488949895089518952895389548955895689578958895989608961896289638964896589668967896889698970897189728973897489758976897789788979898089818982898389848985898689878988898989908991899289938994899589968997899889999000900190029003900490059006900790089009901090119012901390149015901690179018901990209021902290239024902590269027902890299030903190329033903490359036903790389039904090419042904390449045904690479048904990509051905290539054905590569057905890599060906190629063906490659066906790689069907090719072907390749075907690779078907990809081908290839084908590869087908890899090909190929093909490959096909790989099910091019102910391049105910691079108910991109111911291139114911591169117911891199120912191229123912491259126912791289129913091319132913391349135913691379138913991409141914291439144914591469147914891499150915191529153915491559156915791589159916091619162916391649165916691679168916991709171917291739174917591769177917891799180918191829183918491859186918791889189919091919192919391949195919691979198919992009201920292039204920592069207920892099210921192129213921492159216921792189219922092219222922392249225922692279228922992309231923292339234923592369237923892399240924192429243924492459246924792489249925092519252925392549255925692579258925992609261926292639264926592669267926892699270927192729273927492759276927792789279928092819282928392849285928692879288928992909291929292939294929592969297929892999300930193029303930493059306930793089309931093119312931393149315931693179318931993209321932293239324932593269327932893299330933193329333933493359336933793389339934093419342934393449345934693479348934993509351935293539354935593569357935893599360936193629363936493659366936793689369937093719372937393749375937693779378937993809381938293839384938593869387938893899390939193929393939493959396939793989399940094019402940394049405940694079408940994109411941294139414941594169417941894199420942194229423942494259426942794289429943094319432943394349435943694379438943994409441944294439444944594469447944894499450945194529453945494559456945794589459946094619462946394649465946694679468946994709471947294739474947594769477947894799480948194829483948494859486948794889489949094919492949394949495949694979498949995009501950295039504950595069507950895099510951195129513951495159516951795189519952095219522952395249525952695279528952995309531953295339534953595369537953895399540954195429543954495459546954795489549955095519552955395549555955695579558955995609561956295639564956595669567956895699570957195729573957495759576957795789579958095819582958395849585958695879588958995909591959295939594959595969597959895999600960196029603960496059606960796089609961096119612961396149615961696179618961996209621962296239624962596269627962896299630963196329633963496359636963796389639964096419642964396449645964696479648964996509651965296539654965596569657965896599660966196629663966496659666966796689669967096719672967396749675967696779678967996809681968296839684968596869687968896899690969196929693969496959696969796989699970097019702970397049705970697079708970997109711971297139714971597169717971897199720972197229723972497259726972797289729973097319732973397349735973697379738973997409741974297439744974597469747974897499750975197529753975497559756975797589759976097619762976397649765976697679768976997709771977297739774977597769777977897799780978197829783978497859786978797889789979097919792979397949795979697979798979998009801980298039804980598069807980898099810981198129813981498159816981798189819982098219822982398249825982698279828982998309831983298339834983598369837983898399840984198429843984498459846984798489849985098519852985398549855985698579858985998609861986298639864986598669867986898699870987198729873987498759876987798789879988098819882988398849885988698879888988998909891989298939894989598969897989898999900990199029903990499059906990799089909991099119912991399149915991699179918991999209921992299239924992599269927992899299930993199329933993499359936993799389939994099419942994399449945994699479948994999509951995299539954995599569957995899599960996199629963996499659966996799689969997099719972997399749975997699779978997999809981998299839984998599869987998899899990999199929993999499959996999799989999100001000110002100031000410005100061000710008100091001010011100121001310014100151001610017100181001910020100211002210023100241002510026100271002810029100301003110032100331003410035100361003710038100391004010041100421004310044100451004610047100481004910050100511005210053100541005510056100571005810059100601006110062100631006410065100661006710068100691007010071100721007310074100751007610077100781007910080100811008210083100841008510086100871008810089100901009110092100931009410095100961009710098100991010010101101021010310104101051010610107101081010910110101111011210113101141011510116101171011810119101201012110122101231012410125101261012710128101291013010131101321013310134101351013610137101381013910140101411014210143101441014510146101471014810149101501015110152101531015410155101561015710158101591016010161101621016310164101651016610167101681016910170101711017210173101741017510176101771017810179101801018110182101831018410185101861018710188101891019010191101921019310194101951019610197101981019910200102011020210203102041020510206102071020810209102101021110212102131021410215102161021710218102191022010221102221022310224102251022610227102281022910230102311023210233102341023510236102371023810239102401024110242102431024410245102461024710248102491025010251102521025310254102551025610257102581025910260102611026210263102641026510266102671026810269102701027110272102731027410275102761027710278102791028010281102821028310284102851028610287102881028910290102911029210293102941029510296102971029810299103001030110302103031030410305103061030710308103091031010311103121031310314103151031610317103181031910320103211032210323103241032510326103271032810329103301033110332103331033410335103361033710338103391034010341103421034310344103451034610347103481034910350103511035210353103541035510356103571035810359103601036110362103631036410365103661036710368103691037010371103721037310374103751037610377103781037910380103811038210383103841038510386103871038810389103901039110392103931039410395103961039710398103991040010401104021040310404104051040610407104081040910410104111041210413104141041510416104171041810419104201042110422104231042410425104261042710428104291043010431104321043310434104351043610437104381043910440104411044210443104441044510446104471044810449104501045110452104531045410455104561045710458104591046010461104621046310464104651046610467104681046910470104711047210473104741047510476104771047810479104801048110482104831048410485104861048710488104891049010491104921049310494104951049610497104981049910500105011050210503105041050510506105071050810509105101051110512105131051410515105161051710518105191052010521105221052310524105251052610527105281052910530105311053210533105341053510536105371053810539105401054110542105431054410545105461054710548105491055010551105521055310554105551055610557105581055910560105611056210563105641056510566105671056810569105701057110572105731057410575105761057710578105791058010581105821058310584105851058610587105881058910590105911059210593105941059510596105971059810599106001060110602106031060410605106061060710608106091061010611106121061310614106151061610617106181061910620106211062210623106241062510626106271062810629106301063110632106331063410635106361063710638106391064010641106421064310644106451064610647106481064910650106511065210653106541065510656106571065810659106601066110662106631066410665106661066710668106691067010671106721067310674106751067610677106781067910680106811068210683106841068510686106871068810689106901069110692106931069410695106961069710698106991070010701107021070310704107051070610707107081070910710107111071210713107141071510716107171071810719107201072110722107231072410725107261072710728107291073010731107321073310734107351073610737107381073910740107411074210743107441074510746107471074810749107501075110752107531075410755107561075710758107591076010761107621076310764107651076610767107681076910770107711077210773107741077510776107771077810779107801078110782107831078410785107861078710788107891079010791107921079310794107951079610797107981079910800108011080210803108041080510806108071080810809108101081110812108131081410815108161081710818108191082010821108221082310824108251082610827108281082910830108311083210833108341083510836108371083810839108401084110842108431084410845108461084710848108491085010851108521085310854108551085610857108581085910860108611086210863108641086510866108671086810869108701087110872108731087410875108761087710878108791088010881108821088310884108851088610887108881088910890108911089210893108941089510896108971089810899109001090110902109031090410905109061090710908109091091010911109121091310914109151091610917109181091910920109211092210923109241092510926109271092810929109301093110932109331093410935109361093710938109391094010941109421094310944109451094610947109481094910950109511095210953109541095510956109571095810959109601096110962109631096410965109661096710968109691097010971109721097310974109751097610977109781097910980109811098210983109841098510986109871098810989109901099110992109931099410995109961099710998109991100011001110021100311004110051100611007110081100911010110111101211013110141101511016110171101811019110201102111022110231102411025110261102711028110291103011031110321103311034110351103611037110381103911040110411104211043110441104511046110471104811049110501105111052110531105411055110561105711058110591106011061110621106311064110651106611067110681106911070110711107211073110741107511076110771107811079110801108111082110831108411085110861108711088110891109011091110921109311094110951109611097110981109911100111011110211103111041110511106111071110811109111101111111112111131111411115111161111711118111191112011121111221112311124111251112611127111281112911130111311113211133111341113511136111371113811139111401114111142111431114411145111461114711148111491115011151111521115311154111551115611157111581115911160111611116211163111641116511166111671116811169111701117111172111731117411175111761117711178111791118011181111821118311184111851118611187111881118911190111911119211193111941119511196111971119811199112001120111202112031120411205112061120711208112091121011211112121121311214112151121611217112181121911220112211122211223112241122511226112271122811229112301123111232112331123411235112361123711238112391124011241112421124311244112451124611247112481124911250112511125211253112541125511256112571125811259112601126111262112631126411265112661126711268112691127011271112721127311274112751127611277112781127911280112811128211283112841128511286112871128811289112901129111292112931129411295112961129711298112991130011301113021130311304113051130611307113081130911310113111131211313113141131511316113171131811319113201132111322113231132411325113261132711328113291133011331113321133311334113351133611337113381133911340113411134211343113441134511346113471134811349113501135111352113531135411355113561135711358113591136011361113621136311364113651136611367113681136911370113711137211373113741137511376113771137811379113801138111382113831138411385113861138711388113891139011391113921139311394113951139611397113981139911400114011140211403114041140511406114071140811409114101141111412114131141411415114161141711418114191142011421114221142311424114251142611427114281142911430114311143211433114341143511436114371143811439114401144111442114431144411445114461144711448114491145011451114521145311454114551145611457114581145911460114611146211463114641146511466114671146811469114701147111472114731147411475114761147711478114791148011481114821148311484114851148611487114881148911490114911149211493114941149511496114971149811499115001150111502115031150411505115061150711508115091151011511115121151311514115151151611517115181151911520115211152211523115241152511526115271152811529115301153111532115331153411535115361153711538115391154011541115421154311544115451154611547115481154911550115511155211553115541155511556115571155811559115601156111562115631156411565115661156711568115691157011571115721157311574115751157611577115781157911580115811158211583115841158511586115871158811589115901159111592115931159411595115961159711598115991160011601116021160311604116051160611607116081160911610116111161211613116141161511616116171161811619116201162111622116231162411625116261162711628116291163011631116321163311634116351163611637116381163911640116411164211643116441164511646116471164811649116501165111652116531165411655116561165711658116591166011661116621166311664116651166611667116681166911670116711167211673116741167511676116771167811679116801168111682116831168411685116861168711688116891169011691116921169311694116951169611697116981169911700117011170211703117041170511706117071170811709117101171111712117131171411715117161171711718117191172011721117221172311724117251172611727117281172911730117311173211733117341173511736117371173811739117401174111742117431174411745117461174711748117491175011751117521175311754117551175611757117581175911760117611176211763117641176511766117671176811769117701177111772117731177411775117761177711778117791178011781117821178311784117851178611787117881178911790117911179211793117941179511796117971179811799118001180111802118031180411805118061180711808118091181011811118121181311814118151181611817118181181911820118211182211823118241182511826118271182811829118301183111832118331183411835118361183711838118391184011841118421184311844118451184611847118481184911850118511185211853118541185511856118571185811859118601186111862118631186411865118661186711868118691187011871118721187311874118751187611877118781187911880118811188211883118841188511886118871188811889118901189111892118931189411895118961189711898118991190011901119021190311904119051190611907119081190911910119111191211913119141191511916119171191811919119201192111922119231192411925119261192711928119291193011931119321193311934119351193611937119381193911940119411194211943119441194511946119471194811949119501195111952119531195411955119561195711958119591196011961119621196311964119651196611967119681196911970119711197211973119741197511976119771197811979119801198111982119831198411985119861198711988119891199011991119921199311994119951199611997119981199912000120011200212003120041200512006120071200812009120101201112012120131201412015120161201712018120191202012021120221202312024120251202612027120281202912030120311203212033120341203512036120371203812039120401204112042120431204412045120461204712048120491205012051120521205312054120551205612057120581205912060120611206212063120641206512066120671206812069120701207112072120731207412075120761207712078120791208012081120821208312084120851208612087120881208912090120911209212093120941209512096120971209812099121001210112102121031210412105121061210712108121091211012111121121211312114121151211612117121181211912120121211212212123121241212512126121271212812129121301213112132121331213412135121361213712138121391214012141121421214312144121451214612147121481214912150121511215212153121541215512156121571215812159121601216112162121631216412165121661216712168121691217012171121721217312174121751217612177121781217912180121811218212183121841218512186121871218812189121901219112192121931219412195121961219712198121991220012201122021220312204122051220612207122081220912210122111221212213122141221512216122171221812219122201222112222122231222412225122261222712228122291223012231122321223312234122351223612237122381223912240122411224212243122441224512246122471224812249122501225112252122531225412255122561225712258122591226012261122621226312264122651226612267122681226912270122711227212273122741227512276122771227812279122801228112282122831228412285122861228712288122891229012291122921229312294122951229612297122981229912300123011230212303123041230512306123071230812309123101231112312123131231412315123161231712318123191232012321123221232312324123251232612327123281232912330123311233212333123341233512336123371233812339123401234112342123431234412345123461234712348123491235012351123521235312354123551235612357123581235912360123611236212363123641236512366123671236812369123701237112372123731237412375123761237712378123791238012381123821238312384123851238612387123881238912390123911239212393123941239512396123971239812399124001240112402124031240412405124061240712408124091241012411124121241312414124151241612417124181241912420124211242212423124241242512426124271242812429124301243112432124331243412435124361243712438124391244012441124421244312444124451244612447124481244912450124511245212453124541245512456124571245812459124601246112462124631246412465124661246712468124691247012471124721247312474124751247612477124781247912480124811248212483124841248512486124871248812489124901249112492124931249412495124961249712498124991250012501125021250312504125051250612507125081250912510125111251212513125141251512516125171251812519125201252112522125231252412525125261252712528125291253012531125321253312534125351253612537125381253912540125411254212543125441254512546125471254812549125501255112552125531255412555125561255712558125591256012561125621256312564125651256612567125681256912570125711257212573125741257512576125771257812579125801258112582125831258412585125861258712588125891259012591125921259312594125951259612597125981259912600126011260212603126041260512606126071260812609126101261112612126131261412615126161261712618126191262012621126221262312624126251262612627126281262912630126311263212633126341263512636126371263812639126401264112642126431264412645126461264712648126491265012651126521265312654126551265612657126581265912660126611266212663126641266512666126671266812669126701267112672126731267412675126761267712678126791268012681126821268312684126851268612687126881268912690126911269212693126941269512696126971269812699127001270112702127031270412705127061270712708127091271012711127121271312714127151271612717127181271912720127211272212723127241272512726127271272812729127301273112732127331273412735127361273712738127391274012741127421274312744127451274612747127481274912750127511275212753127541275512756127571275812759127601276112762127631276412765127661276712768127691277012771127721277312774127751277612777127781277912780127811278212783127841278512786127871278812789127901279112792127931279412795127961279712798127991280012801128021280312804128051280612807128081280912810128111281212813128141281512816128171281812819128201282112822128231282412825128261282712828128291283012831128321283312834128351283612837128381283912840128411284212843128441284512846128471284812849128501285112852128531285412855128561285712858128591286012861128621286312864128651286612867128681286912870128711287212873128741287512876128771287812879128801288112882128831288412885128861288712888128891289012891128921289312894128951289612897128981289912900129011290212903129041290512906129071290812909129101291112912129131291412915129161291712918129191292012921129221292312924129251292612927129281292912930129311293212933129341293512936129371293812939129401294112942129431294412945129461294712948129491295012951129521295312954129551295612957129581295912960129611296212963129641296512966129671296812969129701297112972129731297412975129761297712978129791298012981129821298312984129851298612987129881298912990129911299212993129941299512996129971299812999130001300113002130031300413005130061300713008130091301013011130121301313014130151301613017130181301913020130211302213023130241302513026130271302813029130301303113032130331303413035130361303713038130391304013041130421304313044130451304613047130481304913050130511305213053130541305513056130571305813059130601306113062130631306413065130661306713068130691307013071130721307313074130751307613077130781307913080130811308213083130841308513086130871308813089130901309113092130931309413095130961309713098130991310013101131021310313104131051310613107131081310913110131111311213113131141311513116131171311813119131201312113122131231312413125131261312713128131291313013131131321313313134131351313613137131381313913140131411314213143131441314513146131471314813149131501315113152131531315413155131561315713158131591316013161131621316313164131651316613167131681316913170131711317213173131741317513176131771317813179131801318113182131831318413185131861318713188131891319013191131921319313194131951319613197131981319913200132011320213203132041320513206132071320813209132101321113212132131321413215132161321713218132191322013221132221322313224132251322613227132281322913230132311323213233132341323513236132371323813239132401324113242132431324413245132461324713248132491325013251132521325313254132551325613257132581325913260132611326213263132641326513266132671326813269132701327113272132731327413275132761327713278132791328013281132821328313284132851328613287132881328913290132911329213293132941329513296132971329813299133001330113302133031330413305133061330713308133091331013311133121331313314133151331613317133181331913320133211332213323133241332513326133271332813329133301333113332133331333413335133361333713338133391334013341133421334313344133451334613347133481334913350133511335213353133541335513356133571335813359133601336113362133631336413365133661336713368133691337013371133721337313374133751337613377133781337913380133811338213383133841338513386133871338813389133901339113392133931339413395133961339713398133991340013401134021340313404134051340613407134081340913410134111341213413134141341513416134171341813419134201342113422134231342413425134261342713428134291343013431134321343313434134351343613437134381343913440134411344213443134441344513446134471344813449134501345113452134531345413455134561345713458134591346013461134621346313464134651346613467134681346913470134711347213473134741347513476134771347813479134801348113482134831348413485134861348713488134891349013491134921349313494134951349613497134981349913500135011350213503135041350513506135071350813509135101351113512135131351413515135161351713518135191352013521135221352313524135251352613527135281352913530135311353213533135341353513536135371353813539135401354113542135431354413545135461354713548135491355013551135521355313554135551355613557135581355913560135611356213563135641356513566135671356813569135701357113572135731357413575135761357713578135791358013581135821358313584135851358613587135881358913590135911359213593135941359513596135971359813599136001360113602136031360413605136061360713608136091361013611136121361313614136151361613617136181361913620136211362213623136241362513626136271362813629136301363113632136331363413635136361363713638136391364013641136421364313644136451364613647136481364913650136511365213653136541365513656136571365813659136601366113662136631366413665136661366713668136691367013671136721367313674136751367613677136781367913680136811368213683136841368513686136871368813689136901369113692136931369413695136961369713698136991370013701137021370313704137051370613707137081370913710137111371213713137141371513716137171371813719137201372113722137231372413725137261372713728137291373013731137321373313734137351373613737137381373913740137411374213743137441374513746137471374813749137501375113752137531375413755137561375713758137591376013761137621376313764137651376613767137681376913770137711377213773137741377513776137771377813779137801378113782137831378413785137861378713788137891379013791137921379313794137951379613797137981379913800138011380213803138041380513806138071380813809138101381113812138131381413815138161381713818138191382013821138221382313824138251382613827138281382913830138311383213833138341383513836138371383813839138401384113842138431384413845138461384713848138491385013851138521385313854138551385613857138581385913860138611386213863138641386513866138671386813869138701387113872138731387413875138761387713878138791388013881138821388313884138851388613887138881388913890138911389213893138941389513896138971389813899139001390113902139031390413905139061390713908139091391013911139121391313914139151391613917139181391913920139211392213923139241392513926139271392813929139301393113932139331393413935139361393713938139391394013941139421394313944139451394613947139481394913950139511395213953139541395513956139571395813959139601396113962139631396413965139661396713968139691397013971139721397313974139751397613977139781397913980139811398213983139841398513986139871398813989139901399113992139931399413995139961399713998139991400014001140021400314004140051400614007140081400914010140111401214013140141401514016140171401814019140201402114022140231402414025140261402714028140291403014031140321403314034140351403614037140381403914040140411404214043140441404514046140471404814049140501405114052140531405414055140561405714058140591406014061140621406314064140651406614067140681406914070140711407214073140741407514076140771407814079140801408114082140831408414085140861408714088140891409014091140921409314094140951409614097140981409914100141011410214103141041410514106141071410814109141101411114112141131411414115141161411714118141191412014121141221412314124141251412614127141281412914130141311413214133141341413514136141371413814139141401414114142141431414414145141461414714148141491415014151141521415314154141551415614157141581415914160141611416214163141641416514166141671416814169141701417114172141731417414175141761417714178141791418014181141821418314184141851418614187141881418914190141911419214193141941419514196141971419814199142001420114202142031420414205142061420714208142091421014211142121421314214142151421614217142181421914220142211422214223142241422514226142271422814229142301423114232142331423414235142361423714238142391424014241142421424314244142451424614247142481424914250142511425214253142541425514256142571425814259142601426114262142631426414265142661426714268142691427014271142721427314274142751427614277142781427914280142811428214283142841428514286142871428814289142901429114292142931429414295142961429714298142991430014301143021430314304143051430614307143081430914310143111431214313143141431514316143171431814319143201432114322143231432414325143261432714328143291433014331143321433314334143351433614337143381433914340143411434214343143441434514346143471434814349143501435114352143531435414355143561435714358143591436014361143621436314364143651436614367143681436914370143711437214373143741437514376143771437814379143801438114382143831438414385143861438714388143891439014391143921439314394143951439614397143981439914400144011440214403144041440514406144071440814409144101441114412144131441414415144161441714418144191442014421144221442314424144251442614427144281442914430144311443214433144341443514436144371443814439144401444114442144431444414445144461444714448144491445014451144521445314454144551445614457144581445914460144611446214463144641446514466144671446814469144701447114472144731447414475144761447714478144791448014481144821448314484144851448614487144881448914490144911449214493144941449514496144971449814499145001450114502145031450414505145061450714508145091451014511145121451314514145151451614517145181451914520145211452214523145241452514526145271452814529145301453114532145331453414535145361453714538145391454014541145421454314544145451454614547145481454914550145511455214553145541455514556145571455814559145601456114562145631456414565145661456714568145691457014571145721457314574145751457614577145781457914580145811458214583145841458514586145871458814589145901459114592145931459414595145961459714598145991460014601146021460314604146051460614607146081460914610146111461214613146141461514616146171461814619146201462114622146231462414625146261462714628146291463014631146321463314634146351463614637146381463914640146411464214643146441464514646146471464814649146501465114652146531465414655146561465714658146591466014661146621466314664146651466614667146681466914670146711467214673146741467514676146771467814679146801468114682146831468414685146861468714688146891469014691146921469314694146951469614697146981469914700147011470214703147041470514706147071470814709147101471114712147131471414715147161471714718147191472014721147221472314724147251472614727147281472914730147311473214733147341473514736147371473814739147401474114742147431474414745147461474714748147491475014751147521475314754147551475614757147581475914760147611476214763147641476514766147671476814769147701477114772147731477414775147761477714778147791478014781147821478314784147851478614787147881478914790147911479214793147941479514796147971479814799148001480114802148031480414805148061480714808148091481014811148121481314814148151481614817148181481914820148211482214823148241482514826148271482814829148301483114832148331483414835148361483714838148391484014841148421484314844148451484614847148481484914850148511485214853148541485514856148571485814859148601486114862148631486414865148661486714868148691487014871148721487314874148751487614877148781487914880148811488214883148841488514886148871488814889148901489114892148931489414895148961489714898148991490014901149021490314904149051490614907149081490914910149111491214913149141491514916149171491814919149201492114922149231492414925149261492714928149291493014931149321493314934149351493614937149381493914940149411494214943149441494514946149471494814949149501495114952149531495414955149561495714958149591496014961149621496314964149651496614967149681496914970149711497214973149741497514976149771497814979149801498114982149831498414985149861498714988149891499014991149921499314994149951499614997149981499915000150011500215003150041500515006150071500815009150101501115012150131501415015150161501715018150191502015021150221502315024150251502615027150281502915030150311503215033150341503515036150371503815039150401504115042150431504415045150461504715048150491505015051150521505315054150551505615057150581505915060150611506215063150641506515066150671506815069150701507115072150731507415075150761507715078150791508015081150821508315084150851508615087150881508915090150911509215093150941509515096150971509815099151001510115102151031510415105151061510715108151091511015111151121511315114151151511615117151181511915120151211512215123151241512515126151271512815129151301513115132151331513415135151361513715138151391514015141151421514315144151451514615147151481514915150151511515215153151541515515156151571515815159151601516115162151631516415165151661516715168151691517015171151721517315174151751517615177151781517915180151811518215183151841518515186151871518815189151901519115192151931519415195151961519715198151991520015201152021520315204152051520615207152081520915210152111521215213152141521515216152171521815219152201522115222152231522415225152261522715228152291523015231152321523315234152351523615237152381523915240152411524215243152441524515246152471524815249152501525115252152531525415255152561525715258152591526015261152621526315264152651526615267152681526915270152711527215273152741527515276152771527815279152801528115282152831528415285152861528715288152891529015291152921529315294152951529615297152981529915300153011530215303153041530515306153071530815309153101531115312153131531415315153161531715318153191532015321153221532315324153251532615327153281532915330153311533215333153341533515336153371533815339153401534115342153431534415345153461534715348153491535015351153521535315354153551535615357153581535915360153611536215363153641536515366153671536815369153701537115372153731537415375153761537715378153791538015381153821538315384153851538615387153881538915390153911539215393153941539515396153971539815399154001540115402154031540415405154061540715408154091541015411154121541315414154151541615417154181541915420154211542215423154241542515426154271542815429154301543115432154331543415435154361543715438154391544015441154421544315444154451544615447154481544915450154511545215453154541545515456154571545815459154601546115462154631546415465154661546715468154691547015471154721547315474154751547615477154781547915480154811548215483154841548515486154871548815489154901549115492154931549415495154961549715498154991550015501155021550315504155051550615507155081550915510155111551215513155141551515516155171551815519155201552115522155231552415525155261552715528155291553015531155321553315534155351553615537155381553915540155411554215543155441554515546155471554815549155501555115552155531555415555155561555715558155591556015561155621556315564155651556615567155681556915570155711557215573155741557515576155771557815579155801558115582155831558415585155861558715588155891559015591155921559315594155951559615597155981559915600156011560215603156041560515606156071560815609156101561115612156131561415615156161561715618156191562015621156221562315624156251562615627156281562915630156311563215633156341563515636156371563815639156401564115642156431564415645156461564715648156491565015651156521565315654156551565615657156581565915660156611566215663156641566515666156671566815669156701567115672156731567415675156761567715678156791568015681156821568315684156851568615687156881568915690156911569215693156941569515696156971569815699157001570115702157031570415705157061570715708157091571015711157121571315714157151571615717157181571915720157211572215723157241572515726157271572815729157301573115732157331573415735157361573715738157391574015741157421574315744157451574615747157481574915750157511575215753157541575515756157571575815759157601576115762157631576415765157661576715768157691577015771157721577315774157751577615777157781577915780157811578215783157841578515786157871578815789157901579115792157931579415795157961579715798157991580015801158021580315804158051580615807158081580915810158111581215813158141581515816158171581815819158201582115822158231582415825158261582715828158291583015831158321583315834158351583615837158381583915840158411584215843158441584515846158471584815849158501585115852158531585415855158561585715858158591586015861158621586315864158651586615867158681586915870158711587215873158741587515876158771587815879158801588115882158831588415885158861588715888158891589015891158921589315894158951589615897158981589915900159011590215903159041590515906159071590815909159101591115912159131591415915159161591715918159191592015921159221592315924159251592615927159281592915930159311593215933159341593515936159371593815939159401594115942159431594415945159461594715948159491595015951159521595315954159551595615957159581595915960159611596215963159641596515966159671596815969159701597115972159731597415975159761597715978159791598015981159821598315984159851598615987159881598915990159911599215993159941599515996159971599815999160001600116002160031600416005160061600716008160091601016011160121601316014160151601616017160181601916020160211602216023160241602516026160271602816029160301603116032160331603416035160361603716038160391604016041160421604316044160451604616047160481604916050160511605216053160541605516056160571605816059160601606116062160631606416065160661606716068160691607016071160721607316074160751607616077160781607916080160811608216083160841608516086160871608816089160901609116092160931609416095160961609716098160991610016101161021610316104161051610616107161081610916110161111611216113161141611516116161171611816119161201612116122161231612416125161261612716128161291613016131161321613316134161351613616137161381613916140161411614216143161441614516146161471614816149161501615116152161531615416155161561615716158161591616016161161621616316164161651616616167161681616916170161711617216173161741617516176161771617816179161801618116182161831618416185161861618716188161891619016191161921619316194161951619616197161981619916200162011620216203162041620516206162071620816209162101621116212162131621416215162161621716218162191622016221162221622316224162251622616227162281622916230162311623216233162341623516236162371623816239162401624116242162431624416245162461624716248162491625016251162521625316254162551625616257162581625916260162611626216263162641626516266162671626816269162701627116272162731627416275162761627716278162791628016281162821628316284162851628616287162881628916290162911629216293162941629516296162971629816299163001630116302163031630416305163061630716308163091631016311163121631316314163151631616317163181631916320163211632216323163241632516326163271632816329163301633116332163331633416335163361633716338163391634016341163421634316344163451634616347163481634916350163511635216353163541635516356163571635816359163601636116362163631636416365163661636716368163691637016371163721637316374163751637616377163781637916380163811638216383163841638516386163871638816389163901639116392163931639416395163961639716398163991640016401164021640316404164051640616407164081640916410164111641216413164141641516416164171641816419164201642116422164231642416425164261642716428164291643016431164321643316434164351643616437164381643916440164411644216443164441644516446164471644816449164501645116452164531645416455164561645716458164591646016461164621646316464164651646616467164681646916470164711647216473164741647516476164771647816479164801648116482164831648416485164861648716488164891649016491164921649316494164951649616497164981649916500165011650216503165041650516506165071650816509165101651116512165131651416515165161651716518165191652016521165221652316524165251652616527165281652916530165311653216533165341653516536165371653816539165401654116542165431654416545165461654716548165491655016551165521655316554165551655616557165581655916560165611656216563165641656516566165671656816569165701657116572165731657416575165761657716578165791658016581165821658316584165851658616587165881658916590165911659216593165941659516596165971659816599166001660116602166031660416605166061660716608166091661016611166121661316614166151661616617166181661916620166211662216623166241662516626166271662816629166301663116632166331663416635166361663716638166391664016641166421664316644166451664616647166481664916650166511665216653166541665516656166571665816659166601666116662166631666416665166661666716668166691667016671166721667316674166751667616677166781667916680166811668216683166841668516686166871668816689166901669116692166931669416695166961669716698166991670016701167021670316704167051670616707167081670916710167111671216713167141671516716167171671816719167201672116722167231672416725167261672716728167291673016731167321673316734167351673616737167381673916740167411674216743167441674516746167471674816749167501675116752167531675416755167561675716758167591676016761167621676316764167651676616767167681676916770167711677216773167741677516776167771677816779167801678116782167831678416785167861678716788167891679016791167921679316794167951679616797167981679916800168011680216803168041680516806168071680816809168101681116812168131681416815168161681716818168191682016821168221682316824168251682616827168281682916830168311683216833168341683516836168371683816839168401684116842168431684416845168461684716848168491685016851168521685316854168551685616857168581685916860168611686216863168641686516866168671686816869168701687116872168731687416875168761687716878168791688016881168821688316884168851688616887168881688916890168911689216893168941689516896168971689816899169001690116902169031690416905169061690716908169091691016911169121691316914169151691616917169181691916920169211692216923169241692516926169271692816929169301693116932169331693416935169361693716938169391694016941169421694316944169451694616947169481694916950169511695216953169541695516956169571695816959169601696116962169631696416965169661696716968169691697016971169721697316974169751697616977169781697916980169811698216983169841698516986169871698816989169901699116992169931699416995169961699716998169991700017001170021700317004170051700617007170081700917010170111701217013170141701517016170171701817019170201702117022170231702417025170261702717028170291703017031170321703317034170351703617037170381703917040170411704217043170441704517046170471704817049170501705117052170531705417055170561705717058170591706017061170621706317064170651706617067170681706917070170711707217073170741707517076170771707817079170801708117082170831708417085170861708717088170891709017091170921709317094170951709617097170981709917100171011710217103171041710517106171071710817109171101711117112171131711417115171161711717118171191712017121171221712317124171251712617127171281712917130171311713217133171341713517136171371713817139171401714117142171431714417145171461714717148171491715017151171521715317154171551715617157171581715917160171611716217163171641716517166171671716817169171701717117172171731717417175171761717717178171791718017181171821718317184171851718617187171881718917190171911719217193171941719517196171971719817199172001720117202172031720417205172061720717208172091721017211172121721317214172151721617217172181721917220172211722217223172241722517226172271722817229172301723117232172331723417235172361723717238172391724017241172421724317244172451724617247172481724917250172511725217253172541725517256172571725817259172601726117262172631726417265172661726717268172691727017271172721727317274172751727617277172781727917280172811728217283172841728517286172871728817289172901729117292172931729417295172961729717298172991730017301173021730317304173051730617307173081730917310173111731217313173141731517316173171731817319173201732117322173231732417325173261732717328173291733017331173321733317334173351733617337173381733917340173411734217343173441734517346173471734817349173501735117352173531735417355173561735717358173591736017361173621736317364173651736617367173681736917370173711737217373173741737517376173771737817379173801738117382173831738417385173861738717388173891739017391173921739317394173951739617397173981739917400174011740217403174041740517406174071740817409174101741117412174131741417415174161741717418174191742017421174221742317424174251742617427174281742917430174311743217433174341743517436174371743817439174401744117442174431744417445174461744717448174491745017451174521745317454174551745617457174581745917460174611746217463174641746517466174671746817469174701747117472174731747417475174761747717478174791748017481174821748317484174851748617487174881748917490174911749217493174941749517496174971749817499175001750117502175031750417505175061750717508175091751017511175121751317514175151751617517175181751917520175211752217523175241752517526175271752817529175301753117532175331753417535175361753717538175391754017541175421754317544175451754617547175481754917550175511755217553175541755517556175571755817559175601756117562175631756417565175661756717568175691757017571175721757317574175751757617577175781757917580175811758217583175841758517586175871758817589175901759117592175931759417595175961759717598175991760017601176021760317604176051760617607176081760917610176111761217613176141761517616176171761817619176201762117622176231762417625176261762717628176291763017631176321763317634176351763617637176381763917640176411764217643176441764517646176471764817649176501765117652176531765417655176561765717658176591766017661176621766317664176651766617667176681766917670176711767217673176741767517676176771767817679176801768117682176831768417685176861768717688176891769017691176921769317694176951769617697176981769917700177011770217703177041770517706177071770817709177101771117712177131771417715177161771717718177191772017721177221772317724177251772617727177281772917730177311773217733177341773517736177371773817739177401774117742177431774417745177461774717748177491775017751177521775317754177551775617757177581775917760177611776217763177641776517766177671776817769177701777117772177731777417775177761777717778177791778017781177821778317784177851778617787177881778917790177911779217793177941779517796177971779817799178001780117802178031780417805178061780717808178091781017811178121781317814178151781617817178181781917820178211782217823178241782517826178271782817829178301783117832178331783417835178361783717838178391784017841178421784317844178451784617847178481784917850178511785217853178541785517856178571785817859178601786117862178631786417865178661786717868178691787017871178721787317874178751787617877178781787917880178811788217883178841788517886178871788817889178901789117892178931789417895178961789717898178991790017901179021790317904179051790617907179081790917910179111791217913179141791517916179171791817919179201792117922179231792417925179261792717928179291793017931179321793317934179351793617937179381793917940179411794217943179441794517946179471794817949179501795117952179531795417955179561795717958179591796017961179621796317964179651796617967179681796917970179711797217973179741797517976179771797817979179801798117982179831798417985179861798717988179891799017991179921799317994179951799617997179981799918000180011800218003180041800518006180071800818009180101801118012180131801418015180161801718018180191802018021180221802318024180251802618027180281802918030180311803218033180341803518036180371803818039180401804118042180431804418045180461804718048180491805018051180521805318054180551805618057180581805918060180611806218063180641806518066180671806818069180701807118072180731807418075180761807718078180791808018081180821808318084180851808618087180881808918090180911809218093180941809518096180971809818099181001810118102181031810418105181061810718108181091811018111181121811318114181151811618117181181811918120181211812218123181241812518126181271812818129181301813118132181331813418135181361813718138181391814018141181421814318144181451814618147181481814918150181511815218153181541815518156181571815818159181601816118162181631816418165181661816718168181691817018171181721817318174181751817618177181781817918180181811818218183181841818518186181871818818189181901819118192181931819418195181961819718198181991820018201182021820318204182051820618207182081820918210182111821218213182141821518216182171821818219182201822118222182231822418225182261822718228182291823018231182321823318234182351823618237182381823918240182411824218243182441824518246182471824818249182501825118252182531825418255182561825718258182591826018261182621826318264182651826618267182681826918270182711827218273182741827518276182771827818279182801828118282182831828418285182861828718288182891829018291182921829318294182951829618297182981829918300183011830218303183041830518306183071830818309183101831118312183131831418315183161831718318183191832018321183221832318324183251832618327183281832918330183311833218333183341833518336183371833818339183401834118342183431834418345183461834718348183491835018351183521835318354183551835618357183581835918360183611836218363183641836518366183671836818369183701837118372183731837418375183761837718378183791838018381183821838318384183851838618387183881838918390183911839218393183941839518396183971839818399184001840118402184031840418405184061840718408184091841018411184121841318414184151841618417184181841918420184211842218423184241842518426184271842818429184301843118432184331843418435184361843718438184391844018441184421844318444184451844618447184481844918450184511845218453184541845518456184571845818459184601846118462184631846418465184661846718468184691847018471184721847318474184751847618477184781847918480184811848218483184841848518486184871848818489184901849118492184931849418495184961849718498184991850018501185021850318504185051850618507185081850918510185111851218513185141851518516185171851818519185201852118522185231852418525185261852718528185291853018531185321853318534185351853618537185381853918540185411854218543185441854518546185471854818549185501855118552185531855418555185561855718558185591856018561185621856318564185651856618567185681856918570185711857218573185741857518576185771857818579185801858118582185831858418585185861858718588185891859018591185921859318594185951859618597185981859918600186011860218603186041860518606186071860818609186101861118612186131861418615186161861718618186191862018621186221862318624186251862618627186281862918630186311863218633186341863518636186371863818639186401864118642186431864418645186461864718648186491865018651186521865318654186551865618657186581865918660186611866218663186641866518666186671866818669186701867118672186731867418675186761867718678186791868018681186821868318684186851868618687186881868918690186911869218693186941869518696186971869818699187001870118702187031870418705187061870718708187091871018711187121871318714187151871618717187181871918720187211872218723187241872518726187271872818729187301873118732187331873418735187361873718738187391874018741187421874318744187451874618747187481874918750187511875218753187541875518756187571875818759187601876118762187631876418765187661876718768187691877018771187721877318774187751877618777187781877918780187811878218783187841878518786187871878818789187901879118792187931879418795187961879718798187991880018801188021880318804188051880618807188081880918810188111881218813188141881518816188171881818819188201882118822188231882418825188261882718828188291883018831188321883318834188351883618837188381883918840188411884218843188441884518846188471884818849188501885118852188531885418855188561885718858188591886018861188621886318864188651886618867188681886918870188711887218873188741887518876188771887818879188801888118882188831888418885188861888718888188891889018891188921889318894188951889618897188981889918900189011890218903189041890518906189071890818909189101891118912189131891418915189161891718918189191892018921189221892318924189251892618927189281892918930189311893218933189341893518936189371893818939189401894118942189431894418945189461894718948189491895018951189521895318954189551895618957189581895918960189611896218963189641896518966189671896818969189701897118972189731897418975189761897718978189791898018981189821898318984189851898618987189881898918990189911899218993189941899518996189971899818999190001900119002190031900419005190061900719008190091901019011190121901319014190151901619017190181901919020190211902219023190241902519026190271902819029190301903119032190331903419035190361903719038190391904019041190421904319044190451904619047190481904919050190511905219053190541905519056190571905819059190601906119062190631906419065190661906719068190691907019071190721907319074190751907619077190781907919080190811908219083190841908519086190871908819089190901909119092190931909419095190961909719098190991910019101191021910319104191051910619107191081910919110191111911219113191141911519116191171911819119191201912119122191231912419125191261912719128191291913019131191321913319134191351913619137191381913919140191411914219143191441914519146191471914819149191501915119152191531915419155191561915719158191591916019161191621916319164191651916619167191681916919170191711917219173191741917519176191771917819179191801918119182191831918419185191861918719188191891919019191191921919319194191951919619197191981919919200192011920219203192041920519206192071920819209192101921119212192131921419215192161921719218192191922019221192221922319224192251922619227192281922919230192311923219233192341923519236192371923819239192401924119242192431924419245192461924719248192491925019251192521925319254192551925619257192581925919260192611926219263192641926519266192671926819269192701927119272192731927419275192761927719278192791928019281192821928319284192851928619287192881928919290192911929219293192941929519296192971929819299193001930119302193031930419305193061930719308193091931019311193121931319314193151931619317193181931919320193211932219323193241932519326193271932819329193301933119332193331933419335193361933719338193391934019341193421934319344193451934619347193481934919350193511935219353193541935519356193571935819359193601936119362193631936419365193661936719368193691937019371193721937319374193751937619377193781937919380193811938219383193841938519386193871938819389193901939119392193931939419395193961939719398193991940019401194021940319404194051940619407194081940919410194111941219413194141941519416194171941819419194201942119422194231942419425194261942719428194291943019431194321943319434194351943619437194381943919440194411944219443194441944519446194471944819449194501945119452194531945419455194561945719458194591946019461194621946319464194651946619467194681946919470194711947219473194741947519476194771947819479194801948119482194831948419485194861948719488194891949019491194921949319494194951949619497194981949919500195011950219503195041950519506195071950819509195101951119512195131951419515195161951719518195191952019521195221952319524195251952619527195281952919530195311953219533195341953519536195371953819539195401954119542195431954419545195461954719548195491955019551195521955319554195551955619557195581955919560195611956219563195641956519566195671956819569195701957119572195731957419575195761957719578195791958019581195821958319584195851958619587195881958919590195911959219593195941959519596195971959819599196001960119602196031960419605196061960719608196091961019611196121961319614196151961619617196181961919620196211962219623196241962519626196271962819629196301963119632196331963419635196361963719638196391964019641196421964319644196451964619647196481964919650196511965219653196541965519656196571965819659196601966119662196631966419665196661966719668196691967019671196721967319674196751967619677196781967919680196811968219683196841968519686196871968819689196901969119692196931969419695196961969719698196991970019701197021970319704197051970619707197081970919710197111971219713197141971519716197171971819719197201972119722197231972419725197261972719728197291973019731197321973319734197351973619737197381973919740197411974219743197441974519746197471974819749197501975119752197531975419755197561975719758197591976019761197621976319764197651976619767197681976919770197711977219773197741977519776197771977819779197801978119782197831978419785197861978719788197891979019791197921979319794197951979619797197981979919800198011980219803198041980519806198071980819809198101981119812198131981419815198161981719818198191982019821198221982319824198251982619827198281982919830198311983219833198341983519836198371983819839198401984119842198431984419845198461984719848198491985019851198521985319854198551985619857198581985919860198611986219863198641986519866198671986819869198701987119872198731987419875198761987719878198791988019881198821988319884198851988619887198881988919890198911989219893198941989519896198971989819899199001990119902199031990419905199061990719908199091991019911199121991319914199151991619917199181991919920199211992219923199241992519926199271992819929199301993119932199331993419935199361993719938199391994019941199421994319944199451994619947199481994919950199511995219953199541995519956199571995819959199601996119962199631996419965199661996719968199691997019971199721997319974199751997619977199781997919980199811998219983199841998519986199871998819989199901999119992199931999419995199961999719998199992000020001200022000320004200052000620007200082000920010200112001220013200142001520016200172001820019200202002120022200232002420025200262002720028200292003020031200322003320034200352003620037200382003920040200412004220043200442004520046200472004820049200502005120052200532005420055200562005720058200592006020061200622006320064200652006620067200682006920070200712007220073200742007520076200772007820079200802008120082200832008420085200862008720088200892009020091200922009320094200952009620097200982009920100201012010220103201042010520106201072010820109201102011120112201132011420115201162011720118201192012020121201222012320124201252012620127201282012920130201312013220133201342013520136201372013820139201402014120142201432014420145201462014720148201492015020151201522015320154201552015620157201582015920160201612016220163201642016520166201672016820169201702017120172201732017420175201762017720178201792018020181201822018320184201852018620187201882018920190201912019220193201942019520196201972019820199202002020120202202032020420205202062020720208202092021020211202122021320214202152021620217202182021920220202212022220223202242022520226202272022820229202302023120232202332023420235202362023720238202392024020241202422024320244202452024620247202482024920250202512025220253202542025520256202572025820259202602026120262202632026420265202662026720268202692027020271202722027320274202752027620277202782027920280202812028220283202842028520286202872028820289202902029120292202932029420295202962029720298202992030020301203022030320304203052030620307203082030920310203112031220313203142031520316203172031820319203202032120322203232032420325203262032720328203292033020331203322033320334203352033620337203382033920340203412034220343203442034520346203472034820349203502035120352203532035420355203562035720358203592036020361203622036320364203652036620367203682036920370203712037220373203742037520376203772037820379203802038120382203832038420385203862038720388203892039020391203922039320394203952039620397203982039920400204012040220403204042040520406204072040820409204102041120412204132041420415204162041720418204192042020421204222042320424204252042620427204282042920430204312043220433204342043520436204372043820439204402044120442204432044420445204462044720448204492045020451204522045320454204552045620457204582045920460204612046220463204642046520466204672046820469204702047120472204732047420475204762047720478204792048020481204822048320484204852048620487204882048920490204912049220493204942049520496204972049820499205002050120502205032050420505205062050720508205092051020511205122051320514205152051620517205182051920520205212052220523205242052520526205272052820529205302053120532205332053420535205362053720538205392054020541205422054320544205452054620547205482054920550205512055220553205542055520556205572055820559205602056120562205632056420565205662056720568205692057020571205722057320574205752057620577205782057920580205812058220583205842058520586205872058820589205902059120592205932059420595205962059720598205992060020601206022060320604206052060620607206082060920610206112061220613206142061520616206172061820619206202062120622206232062420625206262062720628206292063020631206322063320634206352063620637206382063920640206412064220643206442064520646206472064820649206502065120652206532065420655206562065720658206592066020661206622066320664206652066620667206682066920670206712067220673206742067520676206772067820679206802068120682206832068420685206862068720688206892069020691206922069320694206952069620697206982069920700207012070220703207042070520706207072070820709207102071120712207132071420715207162071720718207192072020721207222072320724207252072620727207282072920730207312073220733207342073520736207372073820739207402074120742207432074420745207462074720748207492075020751207522075320754207552075620757207582075920760207612076220763207642076520766207672076820769207702077120772207732077420775207762077720778207792078020781207822078320784207852078620787207882078920790207912079220793207942079520796207972079820799208002080120802208032080420805208062080720808208092081020811208122081320814208152081620817208182081920820208212082220823208242082520826208272082820829208302083120832208332083420835208362083720838208392084020841208422084320844208452084620847208482084920850208512085220853208542085520856208572085820859208602086120862208632086420865208662086720868208692087020871208722087320874208752087620877208782087920880208812088220883208842088520886208872088820889208902089120892208932089420895208962089720898208992090020901209022090320904209052090620907209082090920910209112091220913209142091520916209172091820919209202092120922209232092420925209262092720928209292093020931209322093320934209352093620937209382093920940209412094220943209442094520946209472094820949209502095120952209532095420955209562095720958209592096020961209622096320964209652096620967209682096920970209712097220973209742097520976209772097820979209802098120982209832098420985209862098720988209892099020991209922099320994209952099620997209982099921000210012100221003210042100521006210072100821009210102101121012210132101421015210162101721018210192102021021210222102321024210252102621027210282102921030210312103221033210342103521036210372103821039210402104121042210432104421045210462104721048210492105021051210522105321054210552105621057210582105921060210612106221063210642106521066210672106821069210702107121072210732107421075210762107721078210792108021081210822108321084210852108621087210882108921090210912109221093210942109521096210972109821099211002110121102211032110421105211062110721108211092111021111211122111321114211152111621117211182111921120211212112221123211242112521126211272112821129211302113121132211332113421135211362113721138211392114021141211422114321144211452114621147211482114921150211512115221153211542115521156211572115821159211602116121162211632116421165211662116721168211692117021171211722117321174211752117621177211782117921180211812118221183211842118521186211872118821189211902119121192211932119421195211962119721198211992120021201212022120321204212052120621207212082120921210212112121221213212142121521216212172121821219212202122121222212232122421225212262122721228212292123021231212322123321234212352123621237212382123921240212412124221243212442124521246212472124821249212502125121252212532125421255212562125721258212592126021261212622126321264212652126621267212682126921270212712127221273212742127521276212772127821279212802128121282212832128421285212862128721288212892129021291212922129321294212952129621297212982129921300213012130221303213042130521306213072130821309213102131121312213132131421315213162131721318213192132021321213222132321324213252132621327213282132921330213312133221333213342133521336213372133821339213402134121342213432134421345213462134721348213492135021351213522135321354213552135621357213582135921360213612136221363213642136521366213672136821369213702137121372213732137421375213762137721378213792138021381213822138321384213852138621387213882138921390213912139221393213942139521396213972139821399214002140121402214032140421405214062140721408214092141021411214122141321414214152141621417214182141921420214212142221423214242142521426214272142821429214302143121432214332143421435214362143721438214392144021441214422144321444214452144621447214482144921450214512145221453214542145521456214572145821459214602146121462214632146421465214662146721468214692147021471214722147321474214752147621477214782147921480214812148221483214842148521486214872148821489214902149121492214932149421495214962149721498214992150021501215022150321504215052150621507215082150921510215112151221513215142151521516215172151821519215202152121522215232152421525215262152721528215292153021531215322153321534215352153621537215382153921540215412154221543215442154521546215472154821549215502155121552215532155421555215562155721558215592156021561215622156321564215652156621567215682156921570215712157221573215742157521576215772157821579215802158121582215832158421585215862158721588215892159021591215922159321594215952159621597215982159921600216012160221603216042160521606216072160821609216102161121612216132161421615216162161721618216192162021621216222162321624216252162621627216282162921630216312163221633216342163521636216372163821639216402164121642216432164421645216462164721648216492165021651216522165321654216552165621657216582165921660216612166221663216642166521666216672166821669216702167121672216732167421675216762167721678216792168021681216822168321684216852168621687216882168921690216912169221693216942169521696216972169821699217002170121702217032170421705217062170721708217092171021711217122171321714217152171621717217182171921720217212172221723217242172521726217272172821729217302173121732217332173421735217362173721738217392174021741217422174321744217452174621747217482174921750217512175221753217542175521756217572175821759217602176121762217632176421765217662176721768217692177021771217722177321774217752177621777217782177921780217812178221783217842178521786217872178821789217902179121792217932179421795217962179721798217992180021801218022180321804218052180621807218082180921810218112181221813218142181521816218172181821819218202182121822218232182421825218262182721828218292183021831218322183321834218352183621837218382183921840218412184221843218442184521846218472184821849218502185121852218532185421855218562185721858218592186021861218622186321864218652186621867218682186921870218712187221873218742187521876218772187821879218802188121882218832188421885218862188721888218892189021891218922189321894218952189621897218982189921900219012190221903219042190521906219072190821909219102191121912219132191421915219162191721918219192192021921219222192321924219252192621927219282192921930219312193221933219342193521936219372193821939219402194121942219432194421945219462194721948219492195021951219522195321954219552195621957219582195921960219612196221963219642196521966219672196821969219702197121972219732197421975219762197721978219792198021981219822198321984219852198621987219882198921990219912199221993219942199521996219972199821999220002200122002220032200422005220062200722008220092201022011220122201322014220152201622017220182201922020220212202222023220242202522026220272202822029220302203122032220332203422035220362203722038220392204022041220422204322044220452204622047220482204922050220512205222053220542205522056220572205822059220602206122062220632206422065220662206722068220692207022071220722207322074220752207622077220782207922080220812208222083220842208522086220872208822089220902209122092220932209422095220962209722098220992210022101221022210322104221052210622107221082210922110221112211222113221142211522116221172211822119221202212122122221232212422125221262212722128221292213022131221322213322134221352213622137221382213922140221412214222143221442214522146221472214822149221502215122152221532215422155221562215722158221592216022161221622216322164221652216622167221682216922170221712217222173221742217522176221772217822179221802218122182221832218422185221862218722188221892219022191221922219322194221952219622197221982219922200222012220222203222042220522206222072220822209222102221122212222132221422215222162221722218222192222022221222222222322224222252222622227222282222922230222312223222233222342223522236222372223822239222402224122242222432224422245222462224722248222492225022251222522225322254222552225622257222582225922260222612226222263222642226522266222672226822269222702227122272222732227422275222762227722278222792228022281222822228322284222852228622287222882228922290222912229222293222942229522296222972229822299223002230122302223032230422305223062230722308223092231022311223122231322314223152231622317223182231922320223212232222323223242232522326223272232822329223302233122332223332233422335223362233722338223392234022341223422234322344223452234622347223482234922350223512235222353223542235522356223572235822359223602236122362223632236422365223662236722368223692237022371223722237322374223752237622377223782237922380223812238222383223842238522386223872238822389223902239122392223932239422395223962239722398223992240022401224022240322404224052240622407224082240922410224112241222413224142241522416224172241822419224202242122422224232242422425224262242722428224292243022431224322243322434224352243622437224382243922440224412244222443224442244522446224472244822449224502245122452224532245422455224562245722458224592246022461224622246322464224652246622467224682246922470224712247222473224742247522476224772247822479224802248122482224832248422485224862248722488224892249022491224922249322494224952249622497224982249922500225012250222503225042250522506225072250822509225102251122512225132251422515225162251722518225192252022521225222252322524225252252622527225282252922530225312253222533225342253522536225372253822539225402254122542225432254422545225462254722548225492255022551225522255322554225552255622557225582255922560225612256222563225642256522566225672256822569225702257122572225732257422575225762257722578225792258022581225822258322584225852258622587225882258922590225912259222593225942259522596225972259822599226002260122602226032260422605226062260722608226092261022611226122261322614226152261622617226182261922620226212262222623226242262522626226272262822629226302263122632226332263422635226362263722638226392264022641226422264322644226452264622647226482264922650226512265222653226542265522656226572265822659226602266122662226632266422665226662266722668226692267022671226722267322674226752267622677226782267922680226812268222683226842268522686226872268822689226902269122692226932269422695226962269722698226992270022701227022270322704227052270622707227082270922710227112271222713227142271522716227172271822719227202272122722227232272422725227262272722728227292273022731227322273322734227352273622737227382273922740227412274222743227442274522746227472274822749227502275122752227532275422755227562275722758227592276022761227622276322764227652276622767227682276922770227712277222773227742277522776227772277822779227802278122782227832278422785227862278722788227892279022791227922279322794227952279622797227982279922800228012280222803228042280522806228072280822809228102281122812228132281422815228162281722818228192282022821228222282322824228252282622827228282282922830228312283222833228342283522836228372283822839228402284122842228432284422845228462284722848228492285022851228522285322854228552285622857228582285922860228612286222863228642286522866228672286822869228702287122872228732287422875228762287722878228792288022881228822288322884228852288622887228882288922890228912289222893228942289522896228972289822899229002290122902229032290422905229062290722908229092291022911229122291322914229152291622917229182291922920229212292222923229242292522926229272292822929229302293122932229332293422935229362293722938229392294022941229422294322944229452294622947229482294922950229512295222953229542295522956229572295822959229602296122962229632296422965229662296722968229692297022971229722297322974229752297622977229782297922980229812298222983229842298522986229872298822989229902299122992229932299422995229962299722998229992300023001230022300323004230052300623007230082300923010230112301223013230142301523016230172301823019230202302123022230232302423025230262302723028230292303023031230322303323034230352303623037230382303923040230412304223043230442304523046230472304823049230502305123052230532305423055230562305723058230592306023061230622306323064230652306623067230682306923070230712307223073230742307523076230772307823079230802308123082230832308423085230862308723088230892309023091230922309323094230952309623097230982309923100231012310223103231042310523106231072310823109231102311123112231132311423115231162311723118231192312023121231222312323124231252312623127231282312923130231312313223133231342313523136231372313823139231402314123142231432314423145231462314723148231492315023151231522315323154231552315623157231582315923160231612316223163231642316523166231672316823169231702317123172231732317423175231762317723178231792318023181231822318323184231852318623187231882318923190231912319223193231942319523196231972319823199232002320123202232032320423205232062320723208232092321023211232122321323214232152321623217232182321923220232212322223223232242322523226232272322823229232302323123232232332323423235232362323723238232392324023241232422324323244232452324623247232482324923250232512325223253232542325523256232572325823259232602326123262232632326423265232662326723268232692327023271232722327323274232752327623277232782327923280232812328223283232842328523286232872328823289232902329123292232932329423295232962329723298232992330023301233022330323304233052330623307233082330923310233112331223313233142331523316233172331823319233202332123322233232332423325233262332723328233292333023331233322333323334233352333623337233382333923340233412334223343233442334523346233472334823349233502335123352233532335423355233562335723358233592336023361233622336323364233652336623367233682336923370233712337223373233742337523376233772337823379233802338123382233832338423385233862338723388233892339023391233922339323394233952339623397233982339923400234012340223403234042340523406234072340823409234102341123412234132341423415234162341723418234192342023421234222342323424234252342623427234282342923430234312343223433234342343523436234372343823439234402344123442234432344423445234462344723448234492345023451234522345323454234552345623457234582345923460234612346223463234642346523466234672346823469234702347123472234732347423475234762347723478234792348023481234822348323484234852348623487234882348923490234912349223493234942349523496234972349823499235002350123502235032350423505235062350723508235092351023511235122351323514235152351623517235182351923520235212352223523235242352523526235272352823529235302353123532235332353423535235362353723538235392354023541235422354323544235452354623547235482354923550235512355223553235542355523556235572355823559235602356123562235632356423565235662356723568235692357023571235722357323574235752357623577235782357923580235812358223583235842358523586235872358823589235902359123592235932359423595235962359723598235992360023601236022360323604236052360623607236082360923610236112361223613236142361523616236172361823619236202362123622236232362423625236262362723628236292363023631236322363323634236352363623637236382363923640236412364223643236442364523646236472364823649236502365123652236532365423655236562365723658236592366023661236622366323664236652366623667236682366923670236712367223673236742367523676236772367823679236802368123682236832368423685236862368723688236892369023691236922369323694236952369623697236982369923700237012370223703237042370523706237072370823709237102371123712237132371423715237162371723718237192372023721237222372323724237252372623727237282372923730237312373223733237342373523736237372373823739237402374123742237432374423745237462374723748237492375023751237522375323754237552375623757237582375923760237612376223763237642376523766237672376823769237702377123772237732377423775237762377723778237792378023781237822378323784237852378623787237882378923790237912379223793237942379523796237972379823799238002380123802238032380423805238062380723808238092381023811238122381323814238152381623817238182381923820238212382223823238242382523826238272382823829238302383123832238332383423835238362383723838238392384023841238422384323844238452384623847238482384923850238512385223853238542385523856238572385823859238602386123862238632386423865238662386723868238692387023871238722387323874238752387623877238782387923880238812388223883238842388523886238872388823889238902389123892238932389423895238962389723898238992390023901239022390323904239052390623907239082390923910239112391223913239142391523916239172391823919239202392123922239232392423925239262392723928239292393023931239322393323934239352393623937239382393923940239412394223943239442394523946239472394823949239502395123952239532395423955239562395723958239592396023961239622396323964239652396623967239682396923970239712397223973239742397523976239772397823979239802398123982239832398423985239862398723988239892399023991239922399323994239952399623997239982399924000240012400224003240042400524006240072400824009240102401124012240132401424015240162401724018240192402024021240222402324024240252402624027240282402924030240312403224033240342403524036240372403824039240402404124042240432404424045240462404724048240492405024051240522405324054240552405624057240582405924060240612406224063240642406524066240672406824069240702407124072240732407424075240762407724078240792408024081240822408324084240852408624087240882408924090240912409224093240942409524096240972409824099241002410124102241032410424105241062410724108241092411024111241122411324114241152411624117241182411924120241212412224123241242412524126241272412824129241302413124132241332413424135241362413724138241392414024141241422414324144241452414624147241482414924150241512415224153241542415524156241572415824159241602416124162241632416424165241662416724168241692417024171241722417324174241752417624177241782417924180241812418224183241842418524186241872418824189241902419124192241932419424195241962419724198241992420024201242022420324204242052420624207242082420924210242112421224213242142421524216242172421824219242202422124222242232422424225242262422724228242292423024231242322423324234242352423624237242382423924240242412424224243242442424524246242472424824249242502425124252242532425424255242562425724258242592426024261242622426324264242652426624267242682426924270242712427224273242742427524276242772427824279242802428124282242832428424285242862428724288242892429024291242922429324294242952429624297242982429924300243012430224303243042430524306243072430824309243102431124312243132431424315243162431724318243192432024321243222432324324243252432624327243282432924330243312433224333243342433524336243372433824339243402434124342243432434424345243462434724348243492435024351243522435324354243552435624357243582435924360243612436224363243642436524366243672436824369243702437124372243732437424375243762437724378243792438024381243822438324384243852438624387243882438924390243912439224393243942439524396243972439824399244002440124402244032440424405244062440724408244092441024411244122441324414244152441624417244182441924420244212442224423244242442524426244272442824429244302443124432244332443424435244362443724438244392444024441244422444324444244452444624447244482444924450244512445224453244542445524456244572445824459244602446124462244632446424465244662446724468244692447024471244722447324474244752447624477244782447924480244812448224483244842448524486244872448824489244902449124492244932449424495244962449724498244992450024501245022450324504245052450624507245082450924510245112451224513245142451524516245172451824519245202452124522245232452424525245262452724528245292453024531245322453324534245352453624537245382453924540245412454224543245442454524546245472454824549245502455124552245532455424555245562455724558245592456024561245622456324564245652456624567245682456924570245712457224573245742457524576245772457824579245802458124582245832458424585245862458724588245892459024591245922459324594245952459624597245982459924600246012460224603246042460524606246072460824609246102461124612246132461424615246162461724618246192462024621246222462324624246252462624627246282462924630246312463224633246342463524636246372463824639246402464124642246432464424645246462464724648246492465024651246522465324654246552465624657246582465924660246612466224663246642466524666246672466824669246702467124672246732467424675246762467724678246792468024681246822468324684246852468624687246882468924690246912469224693246942469524696246972469824699247002470124702247032470424705247062470724708247092471024711247122471324714247152471624717247182471924720247212472224723247242472524726247272472824729247302473124732247332473424735247362473724738247392474024741247422474324744247452474624747247482474924750247512475224753247542475524756247572475824759247602476124762247632476424765247662476724768247692477024771247722477324774247752477624777247782477924780247812478224783247842478524786247872478824789247902479124792247932479424795247962479724798247992480024801248022480324804248052480624807248082480924810248112481224813248142481524816248172481824819248202482124822248232482424825248262482724828248292483024831248322483324834248352483624837248382483924840248412484224843248442484524846248472484824849248502485124852248532485424855248562485724858248592486024861248622486324864248652486624867248682486924870248712487224873248742487524876248772487824879248802488124882248832488424885248862488724888248892489024891248922489324894248952489624897248982489924900249012490224903249042490524906249072490824909249102491124912249132491424915249162491724918249192492024921249222492324924249252492624927249282492924930249312493224933249342493524936249372493824939249402494124942249432494424945249462494724948249492495024951249522495324954249552495624957249582495924960249612496224963249642496524966249672496824969249702497124972249732497424975249762497724978249792498024981249822498324984249852498624987249882498924990249912499224993249942499524996249972499824999250002500125002250032500425005250062500725008250092501025011250122501325014250152501625017250182501925020250212502225023250242502525026250272502825029250302503125032250332503425035250362503725038250392504025041250422504325044250452504625047250482504925050250512505225053250542505525056250572505825059250602506125062250632506425065250662506725068250692507025071250722507325074250752507625077250782507925080250812508225083250842508525086250872508825089250902509125092250932509425095250962509725098250992510025101251022510325104251052510625107251082510925110251112511225113251142511525116251172511825119251202512125122251232512425125251262512725128251292513025131251322513325134251352513625137251382513925140251412514225143251442514525146251472514825149251502515125152251532515425155251562515725158251592516025161251622516325164251652516625167251682516925170251712517225173251742517525176251772517825179251802518125182251832518425185251862518725188251892519025191251922519325194251952519625197251982519925200252012520225203252042520525206252072520825209252102521125212252132521425215252162521725218252192522025221252222522325224252252522625227252282522925230252312523225233252342523525236252372523825239252402524125242252432524425245252462524725248252492525025251252522525325254252552525625257252582525925260252612526225263252642526525266252672526825269252702527125272252732527425275252762527725278252792528025281252822528325284252852528625287252882528925290252912529225293252942529525296252972529825299253002530125302253032530425305253062530725308253092531025311253122531325314253152531625317253182531925320253212532225323253242532525326253272532825329253302533125332253332533425335253362533725338253392534025341253422534325344253452534625347253482534925350253512535225353253542535525356253572535825359253602536125362253632536425365253662536725368253692537025371253722537325374253752537625377253782537925380253812538225383253842538525386253872538825389253902539125392253932539425395253962539725398253992540025401254022540325404254052540625407254082540925410254112541225413254142541525416254172541825419254202542125422254232542425425254262542725428254292543025431254322543325434254352543625437254382543925440254412544225443254442544525446254472544825449254502545125452254532545425455254562545725458254592546025461254622546325464254652546625467254682546925470254712547225473254742547525476254772547825479254802548125482254832548425485254862548725488254892549025491254922549325494254952549625497254982549925500255012550225503255042550525506255072550825509255102551125512255132551425515255162551725518255192552025521255222552325524255252552625527255282552925530255312553225533255342553525536255372553825539255402554125542255432554425545255462554725548255492555025551255522555325554255552555625557255582555925560255612556225563255642556525566255672556825569255702557125572255732557425575255762557725578255792558025581255822558325584255852558625587255882558925590255912559225593255942559525596255972559825599256002560125602256032560425605256062560725608256092561025611256122561325614256152561625617256182561925620256212562225623256242562525626256272562825629256302563125632256332563425635256362563725638256392564025641256422564325644256452564625647256482564925650256512565225653256542565525656256572565825659256602566125662256632566425665256662566725668256692567025671256722567325674256752567625677256782567925680256812568225683256842568525686256872568825689256902569125692256932569425695256962569725698256992570025701257022570325704257052570625707257082570925710257112571225713257142571525716257172571825719257202572125722257232572425725257262572725728257292573025731257322573325734257352573625737257382573925740257412574225743257442574525746257472574825749257502575125752257532575425755257562575725758257592576025761257622576325764257652576625767257682576925770257712577225773257742577525776257772577825779257802578125782257832578425785257862578725788257892579025791257922579325794257952579625797257982579925800258012580225803258042580525806258072580825809258102581125812258132581425815258162581725818258192582025821258222582325824258252582625827258282582925830258312583225833258342583525836258372583825839258402584125842258432584425845258462584725848258492585025851258522585325854258552585625857258582585925860258612586225863258642586525866258672586825869258702587125872258732587425875258762587725878258792588025881258822588325884258852588625887258882588925890258912589225893258942589525896258972589825899259002590125902259032590425905259062590725908259092591025911259122591325914259152591625917259182591925920259212592225923259242592525926259272592825929259302593125932259332593425935259362593725938259392594025941259422594325944259452594625947259482594925950259512595225953259542595525956259572595825959259602596125962259632596425965259662596725968259692597025971259722597325974259752597625977259782597925980259812598225983259842598525986259872598825989259902599125992259932599425995259962599725998259992600026001260022600326004260052600626007260082600926010260112601226013260142601526016260172601826019260202602126022260232602426025260262602726028260292603026031260322603326034260352603626037260382603926040260412604226043260442604526046260472604826049260502605126052260532605426055260562605726058260592606026061260622606326064260652606626067260682606926070260712607226073260742607526076260772607826079260802608126082260832608426085260862608726088260892609026091260922609326094260952609626097260982609926100261012610226103261042610526106261072610826109261102611126112261132611426115261162611726118261192612026121261222612326124261252612626127261282612926130261312613226133261342613526136261372613826139261402614126142261432614426145261462614726148261492615026151261522615326154261552615626157261582615926160261612616226163261642616526166261672616826169261702617126172261732617426175261762617726178261792618026181261822618326184261852618626187261882618926190261912619226193261942619526196261972619826199262002620126202262032620426205262062620726208262092621026211262122621326214262152621626217262182621926220262212622226223262242622526226262272622826229262302623126232262332623426235262362623726238262392624026241262422624326244262452624626247262482624926250262512625226253262542625526256262572625826259262602626126262262632626426265262662626726268262692627026271262722627326274262752627626277262782627926280262812628226283262842628526286262872628826289262902629126292262932629426295262962629726298262992630026301263022630326304263052630626307263082630926310263112631226313263142631526316263172631826319263202632126322263232632426325263262632726328263292633026331263322633326334263352633626337263382633926340263412634226343263442634526346263472634826349263502635126352263532635426355263562635726358263592636026361263622636326364263652636626367263682636926370263712637226373263742637526376263772637826379263802638126382263832638426385263862638726388263892639026391263922639326394263952639626397263982639926400264012640226403264042640526406264072640826409264102641126412264132641426415264162641726418264192642026421264222642326424264252642626427264282642926430264312643226433264342643526436264372643826439264402644126442264432644426445264462644726448264492645026451264522645326454264552645626457264582645926460264612646226463264642646526466264672646826469264702647126472264732647426475264762647726478264792648026481264822648326484264852648626487264882648926490264912649226493264942649526496264972649826499265002650126502265032650426505265062650726508265092651026511265122651326514265152651626517265182651926520265212652226523265242652526526265272652826529265302653126532265332653426535265362653726538265392654026541265422654326544265452654626547265482654926550265512655226553265542655526556265572655826559265602656126562265632656426565265662656726568265692657026571265722657326574265752657626577265782657926580265812658226583265842658526586265872658826589265902659126592265932659426595265962659726598265992660026601266022660326604266052660626607266082660926610266112661226613266142661526616266172661826619266202662126622266232662426625266262662726628266292663026631266322663326634266352663626637266382663926640266412664226643266442664526646266472664826649266502665126652266532665426655266562665726658266592666026661266622666326664266652666626667266682666926670266712667226673266742667526676266772667826679266802668126682266832668426685266862668726688266892669026691266922669326694266952669626697266982669926700267012670226703267042670526706267072670826709267102671126712267132671426715267162671726718267192672026721267222672326724267252672626727267282672926730267312673226733267342673526736267372673826739267402674126742267432674426745267462674726748267492675026751267522675326754267552675626757267582675926760267612676226763267642676526766267672676826769267702677126772267732677426775267762677726778267792678026781267822678326784267852678626787267882678926790267912679226793267942679526796267972679826799268002680126802268032680426805268062680726808268092681026811268122681326814268152681626817268182681926820268212682226823268242682526826268272682826829268302683126832268332683426835268362683726838268392684026841268422684326844268452684626847268482684926850268512685226853268542685526856268572685826859268602686126862268632686426865268662686726868268692687026871268722687326874268752687626877268782687926880268812688226883268842688526886268872688826889268902689126892268932689426895268962689726898268992690026901269022690326904269052690626907269082690926910269112691226913269142691526916269172691826919269202692126922269232692426925269262692726928269292693026931269322693326934269352693626937269382693926940269412694226943269442694526946269472694826949269502695126952269532695426955269562695726958269592696026961269622696326964269652696626967269682696926970269712697226973269742697526976269772697826979269802698126982269832698426985269862698726988269892699026991269922699326994269952699626997269982699927000270012700227003270042700527006270072700827009270102701127012270132701427015270162701727018270192702027021270222702327024270252702627027270282702927030270312703227033270342703527036270372703827039270402704127042270432704427045270462704727048270492705027051270522705327054270552705627057270582705927060270612706227063270642706527066270672706827069270702707127072270732707427075270762707727078270792708027081270822708327084270852708627087270882708927090270912709227093270942709527096270972709827099271002710127102271032710427105271062710727108271092711027111271122711327114271152711627117271182711927120271212712227123271242712527126271272712827129271302713127132271332713427135271362713727138271392714027141271422714327144271452714627147271482714927150271512715227153271542715527156271572715827159271602716127162271632716427165271662716727168271692717027171271722717327174271752717627177271782717927180271812718227183271842718527186271872718827189271902719127192271932719427195271962719727198271992720027201272022720327204272052720627207272082720927210272112721227213272142721527216272172721827219272202722127222272232722427225272262722727228272292723027231272322723327234272352723627237272382723927240272412724227243272442724527246272472724827249272502725127252272532725427255272562725727258272592726027261272622726327264272652726627267272682726927270272712727227273272742727527276272772727827279272802728127282272832728427285272862728727288272892729027291272922729327294272952729627297272982729927300273012730227303273042730527306273072730827309273102731127312273132731427315273162731727318273192732027321273222732327324273252732627327273282732927330273312733227333273342733527336273372733827339273402734127342273432734427345273462734727348273492735027351273522735327354273552735627357273582735927360273612736227363273642736527366273672736827369273702737127372273732737427375273762737727378273792738027381273822738327384273852738627387273882738927390273912739227393273942739527396273972739827399274002740127402274032740427405274062740727408274092741027411274122741327414274152741627417274182741927420274212742227423274242742527426274272742827429274302743127432274332743427435274362743727438274392744027441274422744327444274452744627447274482744927450274512745227453274542745527456274572745827459274602746127462274632746427465274662746727468274692747027471274722747327474274752747627477274782747927480274812748227483274842748527486274872748827489274902749127492274932749427495274962749727498274992750027501275022750327504275052750627507275082750927510275112751227513275142751527516275172751827519275202752127522275232752427525275262752727528275292753027531275322753327534275352753627537275382753927540275412754227543275442754527546275472754827549275502755127552275532755427555275562755727558275592756027561275622756327564275652756627567275682756927570275712757227573275742757527576275772757827579275802758127582275832758427585275862758727588275892759027591275922759327594275952759627597275982759927600276012760227603276042760527606276072760827609276102761127612276132761427615276162761727618276192762027621276222762327624276252762627627276282762927630276312763227633276342763527636276372763827639276402764127642276432764427645276462764727648276492765027651276522765327654276552765627657276582765927660276612766227663276642766527666276672766827669276702767127672276732767427675276762767727678276792768027681276822768327684276852768627687276882768927690276912769227693276942769527696276972769827699277002770127702277032770427705277062770727708277092771027711277122771327714277152771627717277182771927720277212772227723277242772527726277272772827729277302773127732277332773427735277362773727738277392774027741277422774327744277452774627747277482774927750277512775227753277542775527756277572775827759277602776127762277632776427765277662776727768277692777027771277722777327774277752777627777277782777927780277812778227783277842778527786277872778827789277902779127792277932779427795277962779727798277992780027801278022780327804278052780627807278082780927810278112781227813278142781527816278172781827819278202782127822278232782427825278262782727828278292783027831278322783327834278352783627837278382783927840278412784227843278442784527846278472784827849278502785127852278532785427855278562785727858278592786027861278622786327864278652786627867278682786927870278712787227873278742787527876278772787827879278802788127882278832788427885278862788727888278892789027891278922789327894278952789627897278982789927900279012790227903279042790527906279072790827909279102791127912279132791427915279162791727918279192792027921279222792327924279252792627927279282792927930279312793227933279342793527936279372793827939279402794127942279432794427945279462794727948279492795027951279522795327954279552795627957279582795927960279612796227963279642796527966279672796827969279702797127972279732797427975279762797727978279792798027981279822798327984279852798627987279882798927990279912799227993279942799527996279972799827999280002800128002280032800428005280062800728008280092801028011280122801328014280152801628017280182801928020280212802228023280242802528026280272802828029280302803128032280332803428035280362803728038280392804028041280422804328044280452804628047280482804928050280512805228053280542805528056280572805828059280602806128062280632806428065280662806728068280692807028071280722807328074280752807628077280782807928080280812808228083280842808528086280872808828089280902809128092280932809428095280962809728098280992810028101281022810328104281052810628107281082810928110281112811228113281142811528116281172811828119281202812128122281232812428125281262812728128281292813028131281322813328134281352813628137281382813928140281412814228143281442814528146281472814828149281502815128152281532815428155281562815728158281592816028161281622816328164281652816628167281682816928170281712817228173281742817528176281772817828179281802818128182281832818428185281862818728188281892819028191281922819328194281952819628197281982819928200282012820228203282042820528206282072820828209282102821128212282132821428215282162821728218282192822028221282222822328224282252822628227282282822928230282312823228233282342823528236282372823828239282402824128242282432824428245282462824728248282492825028251282522825328254282552825628257282582825928260282612826228263282642826528266282672826828269282702827128272282732827428275282762827728278282792828028281282822828328284282852828628287282882828928290282912829228293282942829528296282972829828299283002830128302283032830428305283062830728308283092831028311283122831328314283152831628317283182831928320283212832228323283242832528326283272832828329283302833128332283332833428335283362833728338283392834028341283422834328344283452834628347283482834928350283512835228353283542835528356283572835828359283602836128362283632836428365283662836728368283692837028371283722837328374283752837628377283782837928380283812838228383283842838528386283872838828389283902839128392283932839428395283962839728398283992840028401284022840328404284052840628407284082840928410284112841228413284142841528416284172841828419284202842128422284232842428425284262842728428284292843028431284322843328434284352843628437284382843928440284412844228443284442844528446284472844828449284502845128452284532845428455284562845728458284592846028461284622846328464284652846628467284682846928470284712847228473284742847528476284772847828479284802848128482284832848428485284862848728488284892849028491284922849328494284952849628497284982849928500285012850228503285042850528506285072850828509285102851128512285132851428515285162851728518285192852028521285222852328524285252852628527285282852928530285312853228533285342853528536285372853828539285402854128542285432854428545285462854728548285492855028551285522855328554285552855628557285582855928560285612856228563285642856528566285672856828569285702857128572285732857428575285762857728578285792858028581285822858328584285852858628587285882858928590285912859228593285942859528596285972859828599286002860128602286032860428605286062860728608286092861028611286122861328614286152861628617286182861928620286212862228623286242862528626286272862828629286302863128632286332863428635286362863728638286392864028641286422864328644286452864628647286482864928650286512865228653286542865528656286572865828659286602866128662286632866428665286662866728668286692867028671286722867328674286752867628677286782867928680286812868228683286842868528686286872868828689286902869128692286932869428695286962869728698286992870028701287022870328704287052870628707287082870928710287112871228713287142871528716287172871828719287202872128722287232872428725287262872728728287292873028731287322873328734287352873628737287382873928740287412874228743287442874528746287472874828749287502875128752287532875428755287562875728758287592876028761287622876328764287652876628767287682876928770287712877228773287742877528776287772877828779287802878128782287832878428785287862878728788287892879028791287922879328794287952879628797287982879928800288012880228803288042880528806288072880828809288102881128812288132881428815288162881728818288192882028821288222882328824288252882628827288282882928830288312883228833288342883528836288372883828839288402884128842288432884428845288462884728848288492885028851288522885328854288552885628857288582885928860288612886228863288642886528866288672886828869288702887128872288732887428875288762887728878288792888028881288822888328884288852888628887288882888928890288912889228893288942889528896288972889828899289002890128902289032890428905289062890728908289092891028911289122891328914289152891628917289182891928920289212892228923289242892528926289272892828929289302893128932289332893428935289362893728938289392894028941289422894328944289452894628947289482894928950289512895228953289542895528956289572895828959289602896128962289632896428965289662896728968289692897028971289722897328974289752897628977289782897928980289812898228983289842898528986289872898828989289902899128992289932899428995289962899728998289992900029001290022900329004290052900629007290082900929010290112901229013290142901529016290172901829019290202902129022290232902429025290262902729028290292903029031290322903329034290352903629037290382903929040290412904229043290442904529046290472904829049290502905129052290532905429055290562905729058290592906029061290622906329064290652906629067290682906929070290712907229073290742907529076290772907829079290802908129082290832908429085290862908729088290892909029091290922909329094290952909629097290982909929100291012910229103291042910529106291072910829109291102911129112291132911429115291162911729118291192912029121291222912329124291252912629127291282912929130291312913229133291342913529136291372913829139291402914129142291432914429145291462914729148291492915029151291522915329154291552915629157291582915929160291612916229163291642916529166291672916829169291702917129172291732917429175291762917729178291792918029181291822918329184291852918629187291882918929190291912919229193291942919529196291972919829199292002920129202292032920429205292062920729208292092921029211292122921329214292152921629217292182921929220292212922229223292242922529226292272922829229292302923129232292332923429235292362923729238292392924029241292422924329244292452924629247292482924929250292512925229253292542925529256292572925829259292602926129262292632926429265292662926729268292692927029271292722927329274292752927629277292782927929280292812928229283292842928529286292872928829289292902929129292292932929429295292962929729298292992930029301293022930329304293052930629307293082930929310293112931229313293142931529316293172931829319293202932129322293232932429325293262932729328293292933029331293322933329334293352933629337293382933929340293412934229343293442934529346293472934829349293502935129352293532935429355293562935729358293592936029361293622936329364293652936629367293682936929370293712937229373293742937529376293772937829379293802938129382293832938429385293862938729388293892939029391293922939329394293952939629397293982939929400294012940229403294042940529406294072940829409294102941129412294132941429415294162941729418294192942029421294222942329424294252942629427294282942929430294312943229433294342943529436294372943829439294402944129442294432944429445294462944729448294492945029451294522945329454294552945629457294582945929460294612946229463294642946529466294672946829469294702947129472294732947429475294762947729478294792948029481294822948329484294852948629487294882948929490294912949229493294942949529496294972949829499295002950129502295032950429505295062950729508295092951029511295122951329514295152951629517295182951929520295212952229523295242952529526295272952829529295302953129532295332953429535295362953729538295392954029541295422954329544295452954629547295482954929550295512955229553295542955529556295572955829559295602956129562295632956429565295662956729568295692957029571295722957329574295752957629577295782957929580295812958229583295842958529586295872958829589295902959129592295932959429595295962959729598295992960029601296022960329604296052960629607296082960929610296112961229613296142961529616296172961829619296202962129622296232962429625296262962729628296292963029631296322963329634296352963629637296382963929640296412964229643296442964529646296472964829649296502965129652296532965429655296562965729658296592966029661296622966329664296652966629667296682966929670296712967229673296742967529676296772967829679296802968129682296832968429685296862968729688296892969029691296922969329694296952969629697296982969929700297012970229703297042970529706297072970829709297102971129712297132971429715297162971729718297192972029721297222972329724297252972629727297282972929730297312973229733297342973529736297372973829739297402974129742297432974429745297462974729748297492975029751297522975329754297552975629757297582975929760297612976229763297642976529766297672976829769297702977129772297732977429775297762977729778297792978029781297822978329784297852978629787297882978929790297912979229793297942979529796297972979829799298002980129802298032980429805298062980729808298092981029811298122981329814298152981629817298182981929820298212982229823298242982529826298272982829829298302983129832298332983429835298362983729838298392984029841298422984329844298452984629847298482984929850298512985229853298542985529856298572985829859298602986129862298632986429865298662986729868298692987029871298722987329874298752987629877298782987929880298812988229883298842988529886298872988829889298902989129892298932989429895298962989729898298992990029901299022990329904299052990629907299082990929910299112991229913299142991529916299172991829919299202992129922299232992429925299262992729928299292993029931299322993329934299352993629937299382993929940299412994229943299442994529946299472994829949299502995129952299532995429955299562995729958299592996029961299622996329964299652996629967299682996929970299712997229973299742997529976299772997829979299802998129982299832998429985299862998729988299892999029991299922999329994299952999629997299982999930000300013000230003300043000530006300073000830009300103001130012300133001430015300163001730018300193002030021300223002330024300253002630027300283002930030300313003230033300343003530036300373003830039300403004130042300433004430045300463004730048300493005030051300523005330054300553005630057300583005930060300613006230063300643006530066300673006830069300703007130072300733007430075300763007730078300793008030081300823008330084300853008630087300883008930090300913009230093300943009530096300973009830099301003010130102301033010430105301063010730108301093011030111301123011330114301153011630117301183011930120301213012230123301243012530126301273012830129301303013130132301333013430135301363013730138301393014030141301423014330144301453014630147301483014930150301513015230153301543015530156301573015830159301603016130162301633016430165301663016730168301693017030171301723017330174301753017630177301783017930180301813018230183301843018530186301873018830189301903019130192301933019430195301963019730198301993020030201302023020330204302053020630207302083020930210302113021230213302143021530216302173021830219302203022130222302233022430225302263022730228302293023030231302323023330234302353023630237302383023930240302413024230243302443024530246302473024830249302503025130252302533025430255302563025730258302593026030261302623026330264302653026630267302683026930270302713027230273302743027530276302773027830279302803028130282302833028430285302863028730288302893029030291302923029330294302953029630297302983029930300303013030230303303043030530306303073030830309303103031130312303133031430315303163031730318303193032030321303223032330324303253032630327303283032930330303313033230333303343033530336303373033830339303403034130342303433034430345303463034730348303493035030351303523035330354303553035630357303583035930360303613036230363303643036530366303673036830369303703037130372303733037430375303763037730378303793038030381303823038330384303853038630387303883038930390303913039230393303943039530396303973039830399304003040130402304033040430405304063040730408304093041030411304123041330414304153041630417304183041930420304213042230423304243042530426304273042830429304303043130432304333043430435304363043730438304393044030441304423044330444304453044630447304483044930450304513045230453304543045530456304573045830459304603046130462304633046430465304663046730468304693047030471304723047330474304753047630477304783047930480304813048230483304843048530486304873048830489304903049130492304933049430495304963049730498304993050030501305023050330504305053050630507305083050930510305113051230513305143051530516305173051830519305203052130522305233052430525305263052730528305293053030531305323053330534305353053630537305383053930540305413054230543305443054530546305473054830549305503055130552305533055430555305563055730558305593056030561305623056330564305653056630567305683056930570305713057230573305743057530576305773057830579305803058130582305833058430585305863058730588305893059030591305923059330594305953059630597305983059930600306013060230603306043060530606306073060830609306103061130612306133061430615306163061730618306193062030621306223062330624306253062630627306283062930630306313063230633306343063530636306373063830639306403064130642306433064430645306463064730648306493065030651306523065330654306553065630657306583065930660306613066230663306643066530666306673066830669306703067130672306733067430675306763067730678306793068030681306823068330684306853068630687306883068930690306913069230693306943069530696306973069830699307003070130702307033070430705307063070730708307093071030711307123071330714307153071630717307183071930720307213072230723307243072530726307273072830729307303073130732307333073430735307363073730738307393074030741307423074330744307453074630747307483074930750307513075230753307543075530756307573075830759307603076130762307633076430765307663076730768307693077030771307723077330774307753077630777307783077930780307813078230783307843078530786307873078830789307903079130792307933079430795307963079730798307993080030801308023080330804308053080630807308083080930810308113081230813308143081530816308173081830819308203082130822308233082430825308263082730828308293083030831308323083330834308353083630837308383083930840308413084230843308443084530846308473084830849308503085130852308533085430855308563085730858308593086030861308623086330864308653086630867308683086930870308713087230873308743087530876308773087830879308803088130882308833088430885308863088730888308893089030891308923089330894308953089630897308983089930900309013090230903309043090530906309073090830909309103091130912309133091430915309163091730918309193092030921309223092330924309253092630927309283092930930309313093230933309343093530936309373093830939309403094130942309433094430945309463094730948309493095030951309523095330954309553095630957309583095930960309613096230963309643096530966309673096830969309703097130972309733097430975309763097730978309793098030981309823098330984309853098630987309883098930990309913099230993309943099530996309973099830999310003100131002310033100431005310063100731008310093101031011310123101331014310153101631017310183101931020310213102231023310243102531026310273102831029310303103131032310333103431035310363103731038310393104031041310423104331044310453104631047310483104931050310513105231053310543105531056310573105831059310603106131062310633106431065310663106731068310693107031071310723107331074310753107631077310783107931080310813108231083310843108531086310873108831089310903109131092310933109431095310963109731098310993110031101311023110331104311053110631107311083110931110311113111231113311143111531116311173111831119311203112131122311233112431125311263112731128311293113031131311323113331134311353113631137311383113931140311413114231143311443114531146311473114831149311503115131152311533115431155311563115731158311593116031161311623116331164311653116631167311683116931170311713117231173311743117531176311773117831179311803118131182311833118431185311863118731188311893119031191311923119331194311953119631197311983119931200312013120231203312043120531206312073120831209312103121131212312133121431215312163121731218312193122031221312223122331224312253122631227312283122931230312313123231233312343123531236312373123831239312403124131242312433124431245312463124731248312493125031251312523125331254312553125631257312583125931260312613126231263312643126531266312673126831269312703127131272312733127431275312763127731278312793128031281312823128331284312853128631287312883128931290312913129231293312943129531296312973129831299313003130131302313033130431305313063130731308313093131031311313123131331314313153131631317313183131931320313213132231323313243132531326313273132831329313303133131332313333133431335313363133731338313393134031341313423134331344313453134631347313483134931350313513135231353313543135531356313573135831359313603136131362313633136431365313663136731368313693137031371313723137331374313753137631377313783137931380313813138231383313843138531386313873138831389313903139131392313933139431395313963139731398313993140031401314023140331404314053140631407314083140931410314113141231413314143141531416314173141831419314203142131422314233142431425314263142731428314293143031431314323143331434314353143631437314383143931440314413144231443314443144531446314473144831449314503145131452314533145431455314563145731458314593146031461314623146331464314653146631467314683146931470314713147231473314743147531476314773147831479314803148131482314833148431485314863148731488314893149031491314923149331494314953149631497314983149931500315013150231503315043150531506315073150831509315103151131512315133151431515315163151731518315193152031521315223152331524315253152631527315283152931530315313153231533315343153531536315373153831539315403154131542315433154431545315463154731548315493155031551315523155331554315553155631557315583155931560315613156231563315643156531566315673156831569315703157131572315733157431575315763157731578315793158031581315823158331584315853158631587315883158931590315913159231593315943159531596315973159831599316003160131602316033160431605316063160731608316093161031611316123161331614316153161631617316183161931620316213162231623316243162531626316273162831629316303163131632316333163431635316363163731638316393164031641316423164331644316453164631647316483164931650316513165231653316543165531656316573165831659316603166131662316633166431665316663166731668316693167031671316723167331674316753167631677316783167931680316813168231683316843168531686316873168831689316903169131692316933169431695316963169731698316993170031701317023170331704317053170631707317083170931710317113171231713317143171531716317173171831719317203172131722317233172431725317263172731728317293173031731317323173331734317353173631737317383173931740317413174231743317443174531746317473174831749317503175131752317533175431755317563175731758317593176031761317623176331764317653176631767317683176931770317713177231773317743177531776317773177831779317803178131782317833178431785317863178731788317893179031791317923179331794317953179631797317983179931800318013180231803318043180531806318073180831809318103181131812318133181431815318163181731818318193182031821318223182331824318253182631827318283182931830318313183231833318343183531836318373183831839318403184131842318433184431845318463184731848318493185031851318523185331854318553185631857318583185931860318613186231863318643186531866318673186831869318703187131872318733187431875318763187731878318793188031881318823188331884318853188631887318883188931890318913189231893318943189531896318973189831899319003190131902319033190431905319063190731908319093191031911319123191331914319153191631917319183191931920319213192231923319243192531926319273192831929319303193131932319333193431935319363193731938319393194031941319423194331944319453194631947319483194931950319513195231953319543195531956319573195831959319603196131962319633196431965319663196731968319693197031971319723197331974319753197631977319783197931980319813198231983319843198531986319873198831989319903199131992319933199431995319963199731998319993200032001320023200332004320053200632007320083200932010320113201232013320143201532016320173201832019320203202132022320233202432025320263202732028320293203032031320323203332034320353203632037320383203932040320413204232043320443204532046320473204832049320503205132052320533205432055320563205732058320593206032061320623206332064320653206632067320683206932070320713207232073320743207532076320773207832079320803208132082320833208432085320863208732088320893209032091320923209332094320953209632097320983209932100321013210232103321043210532106321073210832109321103211132112321133211432115321163211732118321193212032121321223212332124321253212632127321283212932130321313213232133321343213532136321373213832139321403214132142321433214432145321463214732148321493215032151321523215332154321553215632157321583215932160321613216232163321643216532166321673216832169321703217132172321733217432175321763217732178321793218032181321823218332184321853218632187321883218932190321913219232193321943219532196321973219832199322003220132202322033220432205322063220732208322093221032211322123221332214322153221632217322183221932220322213222232223322243222532226322273222832229322303223132232322333223432235322363223732238322393224032241322423224332244322453224632247322483224932250322513225232253322543225532256322573225832259322603226132262322633226432265322663226732268322693227032271322723227332274322753227632277322783227932280322813228232283322843228532286322873228832289322903229132292322933229432295322963229732298322993230032301323023230332304323053230632307323083230932310323113231232313323143231532316323173231832319323203232132322323233232432325323263232732328323293233032331323323233332334323353233632337323383233932340323413234232343323443234532346323473234832349323503235132352323533235432355323563235732358323593236032361323623236332364323653236632367323683236932370323713237232373323743237532376323773237832379323803238132382323833238432385323863238732388323893239032391323923239332394323953239632397323983239932400324013240232403324043240532406324073240832409324103241132412324133241432415324163241732418324193242032421324223242332424324253242632427324283242932430324313243232433324343243532436324373243832439324403244132442324433244432445324463244732448324493245032451324523245332454324553245632457324583245932460324613246232463324643246532466324673246832469324703247132472324733247432475324763247732478324793248032481324823248332484324853248632487324883248932490324913249232493324943249532496324973249832499325003250132502325033250432505325063250732508325093251032511325123251332514325153251632517325183251932520325213252232523325243252532526325273252832529325303253132532325333253432535325363253732538325393254032541325423254332544325453254632547325483254932550325513255232553325543255532556325573255832559325603256132562325633256432565325663256732568325693257032571325723257332574325753257632577325783257932580325813258232583325843258532586325873258832589325903259132592325933259432595325963259732598325993260032601326023260332604326053260632607326083260932610326113261232613326143261532616326173261832619326203262132622326233262432625326263262732628326293263032631326323263332634326353263632637326383263932640326413264232643326443264532646326473264832649326503265132652326533265432655326563265732658326593266032661326623266332664326653266632667326683266932670326713267232673326743267532676326773267832679326803268132682326833268432685326863268732688326893269032691326923269332694326953269632697326983269932700327013270232703327043270532706327073270832709327103271132712327133271432715327163271732718327193272032721327223272332724327253272632727327283272932730327313273232733327343273532736327373273832739327403274132742327433274432745327463274732748327493275032751327523275332754327553275632757327583275932760327613276232763327643276532766327673276832769327703277132772327733277432775327763277732778327793278032781327823278332784327853278632787327883278932790327913279232793327943279532796327973279832799328003280132802328033280432805328063280732808328093281032811328123281332814328153281632817328183281932820328213282232823328243282532826328273282832829328303283132832328333283432835328363283732838328393284032841328423284332844328453284632847328483284932850328513285232853328543285532856328573285832859328603286132862328633286432865328663286732868328693287032871328723287332874328753287632877328783287932880328813288232883328843288532886328873288832889328903289132892328933289432895328963289732898328993290032901329023290332904329053290632907329083290932910329113291232913329143291532916329173291832919329203292132922329233292432925329263292732928329293293032931329323293332934329353293632937329383293932940329413294232943329443294532946329473294832949329503295132952329533295432955329563295732958329593296032961329623296332964329653296632967329683296932970329713297232973329743297532976329773297832979329803298132982329833298432985329863298732988329893299032991329923299332994329953299632997329983299933000330013300233003330043300533006330073300833009330103301133012330133301433015330163301733018330193302033021330223302333024330253302633027330283302933030330313303233033330343303533036330373303833039330403304133042330433304433045330463304733048330493305033051330523305333054330553305633057330583305933060330613306233063330643306533066330673306833069330703307133072330733307433075330763307733078330793308033081330823308333084330853308633087330883308933090330913309233093330943309533096330973309833099331003310133102331033310433105331063310733108331093311033111331123311333114331153311633117331183311933120331213312233123331243312533126331273312833129331303313133132331333313433135331363313733138331393314033141331423314333144331453314633147331483314933150331513315233153331543315533156331573315833159331603316133162331633316433165331663316733168331693317033171331723317333174331753317633177331783317933180331813318233183331843318533186331873318833189331903319133192331933319433195331963319733198331993320033201332023320333204332053320633207332083320933210332113321233213332143321533216332173321833219332203322133222332233322433225332263322733228332293323033231332323323333234332353323633237332383323933240332413324233243332443324533246332473324833249332503325133252332533325433255332563325733258332593326033261332623326333264332653326633267332683326933270332713327233273332743327533276332773327833279332803328133282332833328433285332863328733288332893329033291332923329333294332953329633297332983329933300333013330233303333043330533306333073330833309333103331133312333133331433315333163331733318333193332033321333223332333324333253332633327333283332933330333313333233333333343333533336333373333833339333403334133342333433334433345333463334733348333493335033351333523335333354333553335633357333583335933360333613336233363333643336533366333673336833369333703337133372333733337433375333763337733378333793338033381333823338333384333853338633387333883338933390333913339233393333943339533396333973339833399334003340133402334033340433405334063340733408334093341033411334123341333414334153341633417334183341933420334213342233423334243342533426334273342833429334303343133432334333343433435334363343733438334393344033441334423344333444334453344633447334483344933450334513345233453334543345533456334573345833459334603346133462334633346433465334663346733468334693347033471334723347333474334753347633477334783347933480334813348233483334843348533486334873348833489334903349133492334933349433495334963349733498334993350033501335023350333504335053350633507335083350933510335113351233513335143351533516335173351833519335203352133522335233352433525335263352733528335293353033531335323353333534335353353633537335383353933540335413354233543335443354533546335473354833549335503355133552335533355433555335563355733558335593356033561335623356333564335653356633567335683356933570335713357233573335743357533576335773357833579335803358133582335833358433585335863358733588335893359033591335923359333594335953359633597335983359933600336013360233603336043360533606336073360833609336103361133612336133361433615336163361733618336193362033621336223362333624336253362633627336283362933630336313363233633336343363533636336373363833639336403364133642336433364433645336463364733648336493365033651336523365333654336553365633657336583365933660336613366233663336643366533666336673366833669336703367133672336733367433675336763367733678336793368033681336823368333684336853368633687336883368933690336913369233693336943369533696336973369833699337003370133702337033370433705337063370733708337093371033711337123371333714337153371633717337183371933720337213372233723337243372533726337273372833729337303373133732337333373433735337363373733738337393374033741337423374333744337453374633747337483374933750337513375233753337543375533756337573375833759337603376133762337633376433765337663376733768337693377033771337723377333774337753377633777337783377933780337813378233783337843378533786337873378833789337903379133792337933379433795337963379733798337993380033801338023380333804338053380633807338083380933810338113381233813338143381533816338173381833819338203382133822338233382433825338263382733828338293383033831338323383333834338353383633837338383383933840338413384233843338443384533846338473384833849338503385133852338533385433855338563385733858338593386033861338623386333864338653386633867338683386933870338713387233873338743387533876338773387833879338803388133882338833388433885338863388733888338893389033891338923389333894338953389633897338983389933900339013390233903339043390533906339073390833909339103391133912339133391433915339163391733918339193392033921339223392333924339253392633927339283392933930339313393233933339343393533936339373393833939339403394133942339433394433945339463394733948339493395033951339523395333954339553395633957339583395933960339613396233963339643396533966339673396833969339703397133972339733397433975339763397733978339793398033981339823398333984339853398633987339883398933990339913399233993339943399533996339973399833999340003400134002340033400434005340063400734008340093401034011340123401334014340153401634017340183401934020340213402234023340243402534026340273402834029340303403134032340333403434035340363403734038340393404034041340423404334044340453404634047340483404934050340513405234053340543405534056340573405834059340603406134062340633406434065340663406734068340693407034071340723407334074340753407634077340783407934080340813408234083340843408534086340873408834089340903409134092340933409434095340963409734098340993410034101341023410334104341053410634107341083410934110341113411234113341143411534116341173411834119341203412134122341233412434125341263412734128341293413034131341323413334134341353413634137341383413934140341413414234143341443414534146341473414834149341503415134152341533415434155341563415734158341593416034161341623416334164341653416634167341683416934170341713417234173341743417534176341773417834179341803418134182341833418434185341863418734188341893419034191341923419334194341953419634197341983419934200342013420234203342043420534206342073420834209342103421134212342133421434215342163421734218342193422034221342223422334224342253422634227342283422934230342313423234233342343423534236342373423834239342403424134242342433424434245342463424734248342493425034251342523425334254342553425634257342583425934260342613426234263342643426534266342673426834269342703427134272342733427434275342763427734278342793428034281342823428334284342853428634287342883428934290342913429234293342943429534296342973429834299343003430134302343033430434305343063430734308343093431034311343123431334314343153431634317343183431934320343213432234323343243432534326343273432834329343303433134332343333433434335343363433734338343393434034341343423434334344343453434634347343483434934350343513435234353343543435534356343573435834359343603436134362343633436434365343663436734368343693437034371343723437334374343753437634377343783437934380343813438234383343843438534386343873438834389343903439134392343933439434395343963439734398343993440034401344023440334404344053440634407344083440934410344113441234413344143441534416344173441834419344203442134422344233442434425344263442734428344293443034431344323443334434344353443634437344383443934440344413444234443344443444534446344473444834449344503445134452344533445434455344563445734458344593446034461344623446334464344653446634467344683446934470344713447234473344743447534476344773447834479344803448134482344833448434485344863448734488344893449034491344923449334494344953449634497344983449934500345013450234503345043450534506345073450834509345103451134512345133451434515345163451734518345193452034521345223452334524345253452634527345283452934530345313453234533345343453534536345373453834539345403454134542345433454434545345463454734548345493455034551345523455334554345553455634557345583455934560345613456234563345643456534566345673456834569345703457134572345733457434575345763457734578345793458034581345823458334584345853458634587345883458934590345913459234593345943459534596345973459834599346003460134602346033460434605346063460734608346093461034611346123461334614346153461634617346183461934620346213462234623346243462534626346273462834629346303463134632346333463434635346363463734638346393464034641346423464334644346453464634647346483464934650346513465234653346543465534656346573465834659346603466134662346633466434665346663466734668346693467034671346723467334674346753467634677346783467934680346813468234683346843468534686346873468834689346903469134692346933469434695346963469734698346993470034701347023470334704347053470634707347083470934710347113471234713347143471534716347173471834719347203472134722347233472434725347263472734728347293473034731347323473334734347353473634737347383473934740347413474234743347443474534746347473474834749347503475134752347533475434755347563475734758347593476034761347623476334764347653476634767347683476934770347713477234773347743477534776347773477834779347803478134782347833478434785347863478734788347893479034791347923479334794347953479634797347983479934800348013480234803348043480534806348073480834809348103481134812348133481434815348163481734818348193482034821348223482334824348253482634827348283482934830348313483234833348343483534836348373483834839348403484134842348433484434845348463484734848348493485034851348523485334854348553485634857348583485934860348613486234863348643486534866348673486834869348703487134872348733487434875348763487734878348793488034881348823488334884348853488634887348883488934890348913489234893348943489534896348973489834899349003490134902349033490434905349063490734908349093491034911349123491334914349153491634917349183491934920349213492234923349243492534926349273492834929349303493134932349333493434935349363493734938349393494034941349423494334944349453494634947349483494934950349513495234953349543495534956349573495834959349603496134962349633496434965349663496734968349693497034971349723497334974349753497634977349783497934980349813498234983349843498534986349873498834989349903499134992349933499434995349963499734998349993500035001350023500335004350053500635007350083500935010350113501235013350143501535016350173501835019350203502135022350233502435025350263502735028350293503035031350323503335034350353503635037350383503935040350413504235043350443504535046350473504835049350503505135052350533505435055350563505735058350593506035061350623506335064350653506635067350683506935070350713507235073350743507535076350773507835079350803508135082350833508435085350863508735088350893509035091350923509335094350953509635097350983509935100351013510235103351043510535106351073510835109351103511135112351133511435115351163511735118351193512035121351223512335124351253512635127351283512935130351313513235133351343513535136351373513835139351403514135142351433514435145351463514735148351493515035151351523515335154351553515635157351583515935160351613516235163351643516535166351673516835169351703517135172351733517435175351763517735178351793518035181351823518335184351853518635187351883518935190351913519235193351943519535196351973519835199352003520135202352033520435205352063520735208352093521035211352123521335214352153521635217352183521935220352213522235223352243522535226352273522835229352303523135232352333523435235352363523735238352393524035241352423524335244352453524635247352483524935250352513525235253352543525535256352573525835259352603526135262352633526435265352663526735268352693527035271352723527335274352753527635277352783527935280352813528235283352843528535286352873528835289352903529135292352933529435295352963529735298352993530035301353023530335304353053530635307353083530935310353113531235313353143531535316353173531835319353203532135322353233532435325353263532735328353293533035331353323533335334353353533635337353383533935340353413534235343353443534535346353473534835349353503535135352353533535435355353563535735358353593536035361353623536335364353653536635367353683536935370353713537235373353743537535376353773537835379353803538135382353833538435385353863538735388353893539035391353923539335394353953539635397353983539935400354013540235403354043540535406354073540835409354103541135412354133541435415354163541735418354193542035421354223542335424354253542635427354283542935430354313543235433354343543535436354373543835439354403544135442354433544435445354463544735448354493545035451354523545335454354553545635457354583545935460354613546235463354643546535466354673546835469354703547135472354733547435475354763547735478354793548035481354823548335484354853548635487354883548935490354913549235493354943549535496354973549835499355003550135502355033550435505355063550735508355093551035511355123551335514355153551635517355183551935520355213552235523355243552535526355273552835529355303553135532355333553435535355363553735538355393554035541355423554335544355453554635547355483554935550355513555235553355543555535556355573555835559355603556135562355633556435565355663556735568355693557035571355723557335574355753557635577355783557935580355813558235583355843558535586355873558835589355903559135592355933559435595355963559735598355993560035601356023560335604356053560635607356083560935610356113561235613356143561535616356173561835619356203562135622356233562435625356263562735628356293563035631356323563335634356353563635637356383563935640356413564235643356443564535646356473564835649356503565135652356533565435655356563565735658356593566035661356623566335664356653566635667356683566935670356713567235673356743567535676356773567835679356803568135682356833568435685356863568735688356893569035691356923569335694356953569635697356983569935700357013570235703357043570535706357073570835709357103571135712357133571435715357163571735718357193572035721357223572335724357253572635727357283572935730357313573235733357343573535736357373573835739357403574135742357433574435745357463574735748357493575035751357523575335754357553575635757357583575935760357613576235763357643576535766357673576835769357703577135772357733577435775357763577735778357793578035781357823578335784357853578635787357883578935790357913579235793357943579535796357973579835799358003580135802358033580435805358063580735808358093581035811358123581335814358153581635817358183581935820358213582235823358243582535826358273582835829358303583135832358333583435835358363583735838358393584035841358423584335844358453584635847358483584935850358513585235853358543585535856358573585835859358603586135862358633586435865358663586735868358693587035871358723587335874358753587635877358783587935880358813588235883358843588535886358873588835889358903589135892358933589435895358963589735898358993590035901359023590335904359053590635907359083590935910359113591235913359143591535916359173591835919359203592135922359233592435925359263592735928359293593035931359323593335934359353593635937359383593935940359413594235943359443594535946359473594835949359503595135952359533595435955359563595735958359593596035961359623596335964359653596635967359683596935970359713597235973359743597535976359773597835979359803598135982359833598435985359863598735988359893599035991359923599335994359953599635997359983599936000360013600236003360043600536006360073600836009360103601136012360133601436015360163601736018360193602036021360223602336024360253602636027360283602936030360313603236033360343603536036360373603836039360403604136042360433604436045360463604736048360493605036051360523605336054360553605636057360583605936060360613606236063360643606536066360673606836069360703607136072360733607436075360763607736078360793608036081360823608336084360853608636087360883608936090360913609236093360943609536096360973609836099361003610136102361033610436105361063610736108361093611036111361123611336114361153611636117361183611936120361213612236123361243612536126361273612836129361303613136132361333613436135361363613736138361393614036141361423614336144361453614636147361483614936150361513615236153361543615536156361573615836159361603616136162361633616436165361663616736168361693617036171361723617336174361753617636177361783617936180361813618236183361843618536186361873618836189361903619136192361933619436195361963619736198361993620036201362023620336204362053620636207362083620936210362113621236213362143621536216362173621836219362203622136222362233622436225362263622736228362293623036231362323623336234362353623636237362383623936240362413624236243362443624536246362473624836249362503625136252362533625436255362563625736258362593626036261362623626336264362653626636267362683626936270362713627236273362743627536276362773627836279362803628136282362833628436285362863628736288362893629036291362923629336294362953629636297362983629936300363013630236303363043630536306363073630836309363103631136312363133631436315363163631736318363193632036321363223632336324363253632636327363283632936330363313633236333363343633536336363373633836339363403634136342363433634436345363463634736348363493635036351363523635336354363553635636357363583635936360363613636236363363643636536366363673636836369363703637136372363733637436375363763637736378363793638036381363823638336384363853638636387363883638936390363913639236393363943639536396363973639836399364003640136402364033640436405364063640736408364093641036411364123641336414364153641636417364183641936420364213642236423364243642536426364273642836429364303643136432364333643436435364363643736438364393644036441364423644336444364453644636447364483644936450364513645236453364543645536456364573645836459364603646136462364633646436465364663646736468364693647036471364723647336474364753647636477364783647936480364813648236483364843648536486364873648836489364903649136492364933649436495364963649736498364993650036501365023650336504365053650636507365083650936510365113651236513365143651536516365173651836519365203652136522365233652436525365263652736528365293653036531365323653336534365353653636537365383653936540365413654236543365443654536546365473654836549365503655136552365533655436555365563655736558365593656036561365623656336564365653656636567365683656936570365713657236573365743657536576365773657836579365803658136582365833658436585365863658736588365893659036591365923659336594365953659636597365983659936600366013660236603366043660536606366073660836609366103661136612366133661436615366163661736618366193662036621366223662336624366253662636627366283662936630366313663236633366343663536636366373663836639366403664136642366433664436645366463664736648366493665036651366523665336654366553665636657366583665936660366613666236663366643666536666366673666836669366703667136672366733667436675366763667736678366793668036681366823668336684366853668636687366883668936690366913669236693366943669536696366973669836699367003670136702367033670436705367063670736708367093671036711367123671336714367153671636717367183671936720367213672236723367243672536726367273672836729367303673136732367333673436735367363673736738367393674036741367423674336744367453674636747367483674936750367513675236753367543675536756367573675836759367603676136762367633676436765367663676736768367693677036771367723677336774367753677636777367783677936780367813678236783367843678536786367873678836789367903679136792367933679436795367963679736798367993680036801368023680336804368053680636807368083680936810368113681236813368143681536816368173681836819368203682136822368233682436825368263682736828368293683036831368323683336834368353683636837368383683936840368413684236843368443684536846368473684836849368503685136852368533685436855368563685736858368593686036861368623686336864368653686636867368683686936870368713687236873368743687536876368773687836879368803688136882368833688436885368863688736888368893689036891368923689336894368953689636897368983689936900369013690236903369043690536906369073690836909369103691136912369133691436915369163691736918369193692036921369223692336924369253692636927369283692936930369313693236933369343693536936369373693836939369403694136942369433694436945369463694736948369493695036951369523695336954369553695636957369583695936960369613696236963369643696536966369673696836969369703697136972369733697436975369763697736978369793698036981369823698336984369853698636987369883698936990369913699236993369943699536996369973699836999370003700137002370033700437005370063700737008370093701037011370123701337014370153701637017370183701937020370213702237023370243702537026370273702837029370303703137032370333703437035370363703737038370393704037041370423704337044370453704637047370483704937050370513705237053370543705537056370573705837059370603706137062370633706437065370663706737068370693707037071370723707337074370753707637077370783707937080370813708237083370843708537086370873708837089370903709137092370933709437095370963709737098370993710037101371023710337104371053710637107371083710937110371113711237113371143711537116371173711837119371203712137122371233712437125371263712737128371293713037131371323713337134371353713637137371383713937140371413714237143371443714537146371473714837149371503715137152371533715437155371563715737158371593716037161371623716337164371653716637167371683716937170371713717237173371743717537176371773717837179371803718137182371833718437185371863718737188371893719037191371923719337194371953719637197371983719937200372013720237203372043720537206372073720837209372103721137212372133721437215372163721737218372193722037221372223722337224372253722637227372283722937230372313723237233372343723537236372373723837239372403724137242372433724437245372463724737248372493725037251372523725337254372553725637257372583725937260372613726237263372643726537266372673726837269372703727137272372733727437275372763727737278372793728037281372823728337284372853728637287372883728937290372913729237293372943729537296372973729837299373003730137302373033730437305373063730737308373093731037311373123731337314373153731637317373183731937320373213732237323373243732537326373273732837329373303733137332373333733437335373363733737338373393734037341373423734337344373453734637347373483734937350373513735237353373543735537356373573735837359373603736137362373633736437365373663736737368373693737037371373723737337374373753737637377373783737937380373813738237383373843738537386373873738837389373903739137392373933739437395373963739737398373993740037401374023740337404374053740637407374083740937410374113741237413374143741537416374173741837419374203742137422374233742437425374263742737428374293743037431374323743337434374353743637437374383743937440374413744237443374443744537446374473744837449374503745137452374533745437455374563745737458374593746037461374623746337464374653746637467374683746937470374713747237473374743747537476374773747837479374803748137482374833748437485374863748737488374893749037491374923749337494374953749637497374983749937500375013750237503375043750537506375073750837509375103751137512375133751437515375163751737518375193752037521375223752337524375253752637527375283752937530375313753237533375343753537536375373753837539375403754137542375433754437545375463754737548375493755037551375523755337554375553755637557375583755937560375613756237563375643756537566375673756837569375703757137572375733757437575375763757737578375793758037581375823758337584375853758637587375883758937590375913759237593375943759537596375973759837599376003760137602376033760437605376063760737608376093761037611376123761337614376153761637617376183761937620376213762237623376243762537626376273762837629376303763137632376333763437635376363763737638376393764037641376423764337644376453764637647376483764937650376513765237653376543765537656376573765837659376603766137662376633766437665376663766737668376693767037671376723767337674376753767637677376783767937680376813768237683376843768537686376873768837689376903769137692376933769437695376963769737698376993770037701377023770337704377053770637707377083770937710377113771237713377143771537716377173771837719377203772137722377233772437725377263772737728377293773037731377323773337734377353773637737377383773937740377413774237743377443774537746377473774837749377503775137752377533775437755377563775737758377593776037761377623776337764377653776637767377683776937770377713777237773377743777537776377773777837779377803778137782377833778437785377863778737788377893779037791377923779337794377953779637797377983779937800378013780237803378043780537806378073780837809378103781137812378133781437815378163781737818378193782037821378223782337824378253782637827378283782937830378313783237833378343783537836378373783837839378403784137842378433784437845378463784737848378493785037851378523785337854378553785637857378583785937860378613786237863378643786537866378673786837869378703787137872378733787437875378763787737878378793788037881378823788337884378853788637887378883788937890378913789237893378943789537896378973789837899379003790137902379033790437905379063790737908379093791037911379123791337914379153791637917379183791937920379213792237923379243792537926379273792837929379303793137932379333793437935379363793737938379393794037941379423794337944379453794637947379483794937950379513795237953379543795537956379573795837959379603796137962379633796437965379663796737968379693797037971379723797337974379753797637977379783797937980379813798237983379843798537986379873798837989379903799137992379933799437995379963799737998379993800038001380023800338004380053800638007380083800938010380113801238013380143801538016380173801838019380203802138022380233802438025380263802738028380293803038031380323803338034380353803638037380383803938040380413804238043380443804538046380473804838049380503805138052380533805438055380563805738058380593806038061380623806338064380653806638067380683806938070380713807238073380743807538076380773807838079380803808138082380833808438085380863808738088380893809038091380923809338094380953809638097380983809938100381013810238103381043810538106381073810838109381103811138112381133811438115381163811738118381193812038121381223812338124381253812638127381283812938130381313813238133381343813538136381373813838139381403814138142381433814438145381463814738148381493815038151381523815338154381553815638157381583815938160381613816238163381643816538166381673816838169381703817138172381733817438175381763817738178381793818038181381823818338184381853818638187381883818938190381913819238193381943819538196381973819838199382003820138202382033820438205382063820738208382093821038211382123821338214382153821638217382183821938220382213822238223382243822538226382273822838229382303823138232382333823438235382363823738238382393824038241382423824338244382453824638247382483824938250382513825238253382543825538256382573825838259382603826138262382633826438265382663826738268382693827038271382723827338274382753827638277382783827938280382813828238283382843828538286382873828838289382903829138292382933829438295382963829738298382993830038301383023830338304383053830638307383083830938310383113831238313383143831538316383173831838319383203832138322383233832438325383263832738328383293833038331383323833338334383353833638337383383833938340383413834238343383443834538346383473834838349383503835138352383533835438355383563835738358383593836038361383623836338364383653836638367383683836938370383713837238373383743837538376383773837838379383803838138382383833838438385383863838738388383893839038391383923839338394383953839638397383983839938400384013840238403384043840538406384073840838409384103841138412384133841438415384163841738418384193842038421384223842338424384253842638427384283842938430384313843238433384343843538436384373843838439384403844138442384433844438445384463844738448384493845038451384523845338454384553845638457384583845938460384613846238463384643846538466384673846838469384703847138472384733847438475384763847738478384793848038481384823848338484384853848638487384883848938490384913849238493384943849538496384973849838499385003850138502385033850438505385063850738508385093851038511385123851338514385153851638517385183851938520385213852238523385243852538526385273852838529385303853138532385333853438535385363853738538385393854038541385423854338544385453854638547385483854938550385513855238553385543855538556385573855838559385603856138562385633856438565385663856738568385693857038571385723857338574385753857638577385783857938580385813858238583385843858538586385873858838589385903859138592385933859438595385963859738598385993860038601386023860338604386053860638607386083860938610386113861238613386143861538616386173861838619386203862138622386233862438625386263862738628386293863038631386323863338634386353863638637386383863938640386413864238643386443864538646386473864838649386503865138652386533865438655386563865738658386593866038661386623866338664386653866638667386683866938670386713867238673386743867538676386773867838679386803868138682386833868438685386863868738688386893869038691386923869338694386953869638697386983869938700387013870238703387043870538706387073870838709387103871138712387133871438715387163871738718387193872038721387223872338724387253872638727387283872938730387313873238733387343873538736387373873838739387403874138742387433874438745387463874738748387493875038751387523875338754387553875638757387583875938760387613876238763387643876538766387673876838769387703877138772387733877438775387763877738778387793878038781387823878338784387853878638787387883878938790387913879238793387943879538796387973879838799388003880138802388033880438805388063880738808388093881038811388123881338814388153881638817388183881938820388213882238823388243882538826388273882838829388303883138832388333883438835388363883738838388393884038841388423884338844388453884638847388483884938850388513885238853388543885538856388573885838859388603886138862388633886438865388663886738868388693887038871388723887338874388753887638877388783887938880388813888238883388843888538886388873888838889388903889138892388933889438895388963889738898388993890038901389023890338904389053890638907389083890938910389113891238913389143891538916389173891838919389203892138922389233892438925389263892738928389293893038931389323893338934389353893638937389383893938940389413894238943389443894538946389473894838949389503895138952389533895438955389563895738958389593896038961389623896338964389653896638967389683896938970389713897238973389743897538976389773897838979389803898138982389833898438985389863898738988389893899038991389923899338994389953899638997389983899939000390013900239003390043900539006390073900839009390103901139012390133901439015390163901739018390193902039021390223902339024390253902639027390283902939030390313903239033390343903539036390373903839039390403904139042390433904439045390463904739048390493905039051390523905339054390553905639057390583905939060390613906239063390643906539066390673906839069390703907139072390733907439075390763907739078390793908039081390823908339084390853908639087390883908939090390913909239093390943909539096390973909839099391003910139102391033910439105391063910739108391093911039111391123911339114391153911639117391183911939120391213912239123391243912539126391273912839129391303913139132391333913439135391363913739138391393914039141391423914339144391453914639147391483914939150391513915239153391543915539156391573915839159391603916139162391633916439165391663916739168391693917039171391723917339174391753917639177391783917939180391813918239183391843918539186391873918839189391903919139192391933919439195391963919739198391993920039201392023920339204392053920639207392083920939210392113921239213392143921539216392173921839219392203922139222392233922439225392263922739228392293923039231392323923339234392353923639237392383923939240392413924239243392443924539246392473924839249392503925139252392533925439255392563925739258392593926039261392623926339264392653926639267392683926939270392713927239273392743927539276392773927839279392803928139282392833928439285392863928739288392893929039291392923929339294392953929639297392983929939300393013930239303393043930539306393073930839309393103931139312393133931439315393163931739318393193932039321393223932339324393253932639327393283932939330393313933239333393343933539336393373933839339393403934139342393433934439345393463934739348393493935039351393523935339354393553935639357393583935939360393613936239363393643936539366393673936839369393703937139372393733937439375393763937739378393793938039381393823938339384393853938639387393883938939390393913939239393393943939539396393973939839399394003940139402394033940439405394063940739408394093941039411394123941339414394153941639417394183941939420394213942239423394243942539426394273942839429394303943139432394333943439435394363943739438394393944039441394423944339444394453944639447394483944939450394513945239453394543945539456394573945839459394603946139462394633946439465394663946739468394693947039471394723947339474394753947639477394783947939480394813948239483394843948539486394873948839489394903949139492394933949439495394963949739498394993950039501395023950339504395053950639507395083950939510395113951239513395143951539516395173951839519395203952139522395233952439525395263952739528395293953039531395323953339534395353953639537395383953939540395413954239543395443954539546395473954839549395503955139552395533955439555395563955739558395593956039561395623956339564395653956639567395683956939570395713957239573395743957539576395773957839579395803958139582395833958439585395863958739588395893959039591395923959339594395953959639597395983959939600396013960239603396043960539606396073960839609396103961139612396133961439615396163961739618396193962039621396223962339624396253962639627396283962939630396313963239633396343963539636396373963839639396403964139642396433964439645396463964739648396493965039651396523965339654396553965639657396583965939660396613966239663396643966539666396673966839669396703967139672396733967439675396763967739678396793968039681396823968339684396853968639687396883968939690396913969239693396943969539696396973969839699397003970139702397033970439705397063970739708397093971039711397123971339714397153971639717397183971939720397213972239723397243972539726397273972839729397303973139732397333973439735397363973739738397393974039741397423974339744397453974639747397483974939750397513975239753397543975539756397573975839759397603976139762397633976439765397663976739768397693977039771397723977339774397753977639777397783977939780397813978239783397843978539786397873978839789397903979139792397933979439795397963979739798397993980039801398023980339804398053980639807398083980939810398113981239813398143981539816398173981839819398203982139822398233982439825398263982739828398293983039831398323983339834398353983639837398383983939840398413984239843398443984539846398473984839849398503985139852398533985439855398563985739858398593986039861398623986339864398653986639867398683986939870398713987239873398743987539876398773987839879398803988139882398833988439885398863988739888398893989039891398923989339894398953989639897398983989939900399013990239903399043990539906399073990839909399103991139912399133991439915399163991739918399193992039921399223992339924399253992639927399283992939930399313993239933399343993539936399373993839939399403994139942399433994439945399463994739948399493995039951399523995339954399553995639957399583995939960399613996239963399643996539966399673996839969399703997139972399733997439975399763997739978399793998039981399823998339984399853998639987399883998939990399913999239993399943999539996399973999839999400004000140002400034000440005400064000740008400094001040011400124001340014400154001640017400184001940020400214002240023400244002540026400274002840029400304003140032400334003440035400364003740038400394004040041400424004340044400454004640047400484004940050400514005240053400544005540056400574005840059400604006140062400634006440065400664006740068400694007040071400724007340074400754007640077400784007940080400814008240083400844008540086400874008840089400904009140092400934009440095400964009740098400994010040101401024010340104401054010640107401084010940110401114011240113401144011540116401174011840119401204012140122401234012440125401264012740128401294013040131401324013340134401354013640137401384013940140401414014240143401444014540146401474014840149401504015140152401534015440155401564015740158401594016040161401624016340164401654016640167401684016940170401714017240173401744017540176401774017840179401804018140182401834018440185401864018740188401894019040191401924019340194401954019640197401984019940200402014020240203402044020540206402074020840209402104021140212402134021440215402164021740218402194022040221402224022340224402254022640227402284022940230402314023240233402344023540236402374023840239402404024140242402434024440245402464024740248402494025040251402524025340254402554025640257402584025940260402614026240263402644026540266402674026840269402704027140272402734027440275402764027740278402794028040281402824028340284402854028640287402884028940290402914029240293402944029540296402974029840299403004030140302403034030440305403064030740308403094031040311403124031340314403154031640317403184031940320403214032240323403244032540326403274032840329403304033140332403334033440335403364033740338403394034040341403424034340344403454034640347403484034940350403514035240353403544035540356403574035840359403604036140362403634036440365403664036740368403694037040371403724037340374403754037640377403784037940380403814038240383403844038540386403874038840389403904039140392403934039440395403964039740398403994040040401404024040340404404054040640407404084040940410404114041240413404144041540416404174041840419404204042140422404234042440425404264042740428404294043040431404324043340434404354043640437404384043940440404414044240443404444044540446404474044840449404504045140452404534045440455404564045740458404594046040461404624046340464404654046640467404684046940470404714047240473404744047540476404774047840479404804048140482404834048440485404864048740488404894049040491404924049340494404954049640497404984049940500405014050240503405044050540506405074050840509405104051140512405134051440515405164051740518405194052040521405224052340524405254052640527405284052940530405314053240533405344053540536405374053840539405404054140542405434054440545405464054740548405494055040551405524055340554405554055640557405584055940560405614056240563405644056540566405674056840569405704057140572405734057440575405764057740578405794058040581405824058340584405854058640587405884058940590405914059240593405944059540596405974059840599406004060140602406034060440605406064060740608406094061040611406124061340614406154061640617406184061940620406214062240623406244062540626406274062840629406304063140632406334063440635406364063740638406394064040641406424064340644406454064640647406484064940650406514065240653406544065540656406574065840659406604066140662406634066440665406664066740668406694067040671406724067340674406754067640677406784067940680406814068240683406844068540686406874068840689406904069140692406934069440695406964069740698406994070040701407024070340704407054070640707407084070940710407114071240713407144071540716407174071840719407204072140722407234072440725407264072740728407294073040731407324073340734407354073640737407384073940740407414074240743407444074540746407474074840749407504075140752407534075440755407564075740758407594076040761407624076340764407654076640767407684076940770407714077240773407744077540776407774077840779407804078140782407834078440785407864078740788407894079040791407924079340794407954079640797407984079940800408014080240803408044080540806408074080840809408104081140812408134081440815408164081740818408194082040821408224082340824408254082640827408284082940830408314083240833408344083540836408374083840839408404084140842408434084440845408464084740848408494085040851408524085340854408554085640857408584085940860408614086240863408644086540866408674086840869408704087140872408734087440875408764087740878408794088040881408824088340884408854088640887408884088940890408914089240893408944089540896408974089840899409004090140902409034090440905409064090740908409094091040911409124091340914409154091640917409184091940920409214092240923409244092540926409274092840929409304093140932409334093440935409364093740938409394094040941409424094340944409454094640947409484094940950409514095240953409544095540956409574095840959409604096140962409634096440965409664096740968409694097040971409724097340974409754097640977409784097940980409814098240983409844098540986409874098840989409904099140992409934099440995409964099740998409994100041001410024100341004410054100641007410084100941010410114101241013410144101541016410174101841019410204102141022410234102441025410264102741028410294103041031410324103341034410354103641037410384103941040410414104241043410444104541046410474104841049410504105141052410534105441055410564105741058410594106041061410624106341064410654106641067410684106941070410714107241073410744107541076410774107841079410804108141082410834108441085410864108741088410894109041091410924109341094410954109641097410984109941100411014110241103411044110541106411074110841109411104111141112411134111441115411164111741118411194112041121411224112341124411254112641127411284112941130411314113241133411344113541136411374113841139411404114141142411434114441145411464114741148411494115041151411524115341154411554115641157411584115941160411614116241163411644116541166411674116841169411704117141172411734117441175411764117741178411794118041181411824118341184411854118641187411884118941190411914119241193411944119541196411974119841199412004120141202412034120441205412064120741208412094121041211412124121341214412154121641217412184121941220412214122241223412244122541226412274122841229412304123141232412334123441235412364123741238412394124041241412424124341244412454124641247412484124941250412514125241253412544125541256412574125841259412604126141262412634126441265412664126741268412694127041271412724127341274412754127641277412784127941280412814128241283412844128541286412874128841289412904129141292412934129441295412964129741298412994130041301413024130341304413054130641307413084130941310413114131241313413144131541316413174131841319413204132141322413234132441325413264132741328413294133041331413324133341334413354133641337413384133941340413414134241343413444134541346413474134841349413504135141352413534135441355413564135741358413594136041361413624136341364413654136641367413684136941370413714137241373413744137541376413774137841379413804138141382413834138441385413864138741388413894139041391413924139341394413954139641397413984139941400414014140241403414044140541406414074140841409414104141141412414134141441415414164141741418414194142041421414224142341424414254142641427414284142941430414314143241433414344143541436414374143841439414404144141442414434144441445414464144741448414494145041451414524145341454414554145641457414584145941460414614146241463414644146541466414674146841469414704147141472414734147441475414764147741478414794148041481414824148341484414854148641487414884148941490414914149241493414944149541496414974149841499415004150141502415034150441505415064150741508415094151041511415124151341514415154151641517415184151941520415214152241523415244152541526415274152841529415304153141532415334153441535415364153741538415394154041541415424154341544415454154641547415484154941550415514155241553415544155541556415574155841559415604156141562415634156441565415664156741568415694157041571415724157341574415754157641577415784157941580415814158241583415844158541586415874158841589415904159141592415934159441595415964159741598415994160041601416024160341604416054160641607416084160941610416114161241613416144161541616416174161841619416204162141622416234162441625416264162741628416294163041631416324163341634416354163641637416384163941640416414164241643416444164541646416474164841649416504165141652416534165441655416564165741658416594166041661416624166341664416654166641667416684166941670416714167241673416744167541676416774167841679416804168141682416834168441685416864168741688416894169041691416924169341694416954169641697416984169941700417014170241703417044170541706417074170841709417104171141712417134171441715417164171741718417194172041721417224172341724417254172641727417284172941730417314173241733417344173541736417374173841739417404174141742417434174441745417464174741748417494175041751417524175341754417554175641757417584175941760417614176241763417644176541766417674176841769417704177141772417734177441775417764177741778417794178041781417824178341784417854178641787417884178941790417914179241793417944179541796417974179841799418004180141802418034180441805418064180741808418094181041811418124181341814418154181641817418184181941820418214182241823418244182541826418274182841829418304183141832418334183441835418364183741838418394184041841418424184341844418454184641847418484184941850418514185241853418544185541856418574185841859418604186141862418634186441865418664186741868418694187041871418724187341874418754187641877418784187941880418814188241883418844188541886418874188841889418904189141892418934189441895418964189741898418994190041901419024190341904419054190641907419084190941910419114191241913419144191541916419174191841919419204192141922419234192441925419264192741928419294193041931419324193341934419354193641937419384193941940419414194241943419444194541946419474194841949419504195141952419534195441955419564195741958419594196041961419624196341964419654196641967419684196941970419714197241973419744197541976419774197841979419804198141982419834198441985419864198741988419894199041991419924199341994419954199641997419984199942000420014200242003420044200542006420074200842009420104201142012420134201442015420164201742018420194202042021420224202342024420254202642027420284202942030420314203242033420344203542036420374203842039420404204142042420434204442045420464204742048420494205042051420524205342054420554205642057420584205942060420614206242063420644206542066420674206842069420704207142072420734207442075420764207742078420794208042081420824208342084420854208642087420884208942090420914209242093420944209542096420974209842099421004210142102421034210442105421064210742108421094211042111421124211342114421154211642117421184211942120421214212242123421244212542126421274212842129421304213142132421334213442135421364213742138421394214042141421424214342144421454214642147421484214942150421514215242153421544215542156421574215842159421604216142162421634216442165421664216742168421694217042171421724217342174421754217642177421784217942180421814218242183421844218542186421874218842189421904219142192421934219442195421964219742198421994220042201422024220342204422054220642207422084220942210422114221242213422144221542216422174221842219422204222142222422234222442225422264222742228422294223042231422324223342234422354223642237422384223942240422414224242243422444224542246422474224842249422504225142252422534225442255422564225742258422594226042261422624226342264422654226642267422684226942270422714227242273422744227542276422774227842279422804228142282422834228442285422864228742288422894229042291422924229342294422954229642297422984229942300423014230242303423044230542306423074230842309423104231142312423134231442315423164231742318423194232042321423224232342324423254232642327423284232942330423314233242333423344233542336423374233842339423404234142342423434234442345423464234742348423494235042351423524235342354423554235642357423584235942360423614236242363423644236542366423674236842369423704237142372423734237442375423764237742378423794238042381423824238342384423854238642387423884238942390423914239242393423944239542396423974239842399424004240142402424034240442405424064240742408424094241042411424124241342414424154241642417424184241942420424214242242423424244242542426424274242842429424304243142432424334243442435424364243742438424394244042441424424244342444424454244642447424484244942450424514245242453424544245542456424574245842459424604246142462424634246442465424664246742468424694247042471424724247342474424754247642477424784247942480424814248242483424844248542486424874248842489424904249142492424934249442495424964249742498424994250042501425024250342504425054250642507425084250942510425114251242513425144251542516425174251842519425204252142522425234252442525425264252742528425294253042531425324253342534425354253642537425384253942540425414254242543425444254542546425474254842549425504255142552425534255442555425564255742558425594256042561425624256342564425654256642567425684256942570425714257242573425744257542576425774257842579425804258142582425834258442585425864258742588425894259042591425924259342594425954259642597425984259942600426014260242603426044260542606426074260842609426104261142612426134261442615426164261742618426194262042621426224262342624426254262642627426284262942630426314263242633426344263542636426374263842639426404264142642426434264442645426464264742648426494265042651426524265342654426554265642657426584265942660426614266242663426644266542666426674266842669426704267142672426734267442675426764267742678426794268042681426824268342684426854268642687426884268942690426914269242693426944269542696426974269842699427004270142702427034270442705427064270742708427094271042711427124271342714427154271642717427184271942720427214272242723427244272542726427274272842729427304273142732427334273442735427364273742738427394274042741427424274342744427454274642747427484274942750427514275242753427544275542756427574275842759427604276142762427634276442765427664276742768427694277042771427724277342774427754277642777427784277942780427814278242783427844278542786427874278842789427904279142792427934279442795427964279742798427994280042801428024280342804428054280642807428084280942810428114281242813428144281542816428174281842819428204282142822428234282442825428264282742828428294283042831428324283342834428354283642837428384283942840428414284242843428444284542846428474284842849428504285142852428534285442855428564285742858428594286042861428624286342864428654286642867428684286942870428714287242873428744287542876428774287842879428804288142882428834288442885428864288742888428894289042891428924289342894428954289642897428984289942900429014290242903429044290542906429074290842909429104291142912429134291442915429164291742918429194292042921429224292342924429254292642927429284292942930429314293242933429344293542936429374293842939429404294142942429434294442945429464294742948429494295042951429524295342954429554295642957429584295942960429614296242963429644296542966429674296842969429704297142972429734297442975429764297742978429794298042981429824298342984429854298642987429884298942990429914299242993429944299542996429974299842999430004300143002430034300443005430064300743008430094301043011430124301343014430154301643017430184301943020430214302243023430244302543026430274302843029430304303143032430334303443035430364303743038430394304043041430424304343044430454304643047430484304943050430514305243053430544305543056430574305843059430604306143062430634306443065430664306743068430694307043071430724307343074430754307643077430784307943080430814308243083430844308543086430874308843089430904309143092430934309443095430964309743098430994310043101431024310343104431054310643107431084310943110431114311243113431144311543116431174311843119431204312143122431234312443125431264312743128431294313043131431324313343134431354313643137431384313943140431414314243143431444314543146431474314843149431504315143152431534315443155431564315743158431594316043161431624316343164431654316643167431684316943170431714317243173431744317543176431774317843179431804318143182431834318443185431864318743188431894319043191431924319343194431954319643197431984319943200432014320243203432044320543206432074320843209432104321143212432134321443215432164321743218432194322043221432224322343224432254322643227432284322943230432314323243233432344323543236432374323843239432404324143242432434324443245432464324743248432494325043251432524325343254432554325643257432584325943260432614326243263432644326543266432674326843269432704327143272432734327443275432764327743278432794328043281432824328343284432854328643287432884328943290432914329243293432944329543296432974329843299433004330143302433034330443305433064330743308433094331043311433124331343314433154331643317433184331943320433214332243323433244332543326433274332843329433304333143332433334333443335433364333743338433394334043341433424334343344433454334643347433484334943350433514335243353433544335543356433574335843359433604336143362433634336443365433664336743368433694337043371433724337343374433754337643377433784337943380433814338243383433844338543386433874338843389433904339143392433934339443395433964339743398433994340043401434024340343404434054340643407434084340943410434114341243413434144341543416434174341843419434204342143422434234342443425434264342743428434294343043431434324343343434434354343643437434384343943440434414344243443434444344543446434474344843449434504345143452434534345443455434564345743458434594346043461434624346343464434654346643467434684346943470434714347243473434744347543476434774347843479434804348143482434834348443485434864348743488434894349043491434924349343494434954349643497434984349943500435014350243503435044350543506435074350843509435104351143512435134351443515435164351743518435194352043521435224352343524435254352643527435284352943530435314353243533435344353543536435374353843539435404354143542435434354443545435464354743548435494355043551435524355343554435554355643557435584355943560435614356243563435644356543566435674356843569435704357143572435734357443575435764357743578435794358043581435824358343584435854358643587435884358943590435914359243593435944359543596435974359843599436004360143602436034360443605436064360743608436094361043611436124361343614436154361643617436184361943620436214362243623436244362543626436274362843629436304363143632436334363443635436364363743638436394364043641436424364343644436454364643647436484364943650436514365243653436544365543656436574365843659436604366143662436634366443665436664366743668436694367043671436724367343674436754367643677436784367943680436814368243683436844368543686436874368843689436904369143692436934369443695436964369743698436994370043701437024370343704437054370643707437084370943710437114371243713437144371543716437174371843719437204372143722437234372443725437264372743728437294373043731437324373343734437354373643737437384373943740437414374243743437444374543746437474374843749437504375143752437534375443755437564375743758437594376043761437624376343764437654376643767437684376943770437714377243773437744377543776437774377843779437804378143782437834378443785437864378743788437894379043791437924379343794437954379643797437984379943800438014380243803438044380543806438074380843809438104381143812438134381443815438164381743818438194382043821438224382343824438254382643827438284382943830438314383243833438344383543836438374383843839438404384143842438434384443845438464384743848438494385043851438524385343854438554385643857438584385943860438614386243863438644386543866438674386843869438704387143872438734387443875438764387743878438794388043881438824388343884438854388643887438884388943890438914389243893438944389543896438974389843899439004390143902439034390443905439064390743908439094391043911439124391343914439154391643917439184391943920439214392243923439244392543926439274392843929439304393143932439334393443935439364393743938439394394043941439424394343944439454394643947439484394943950439514395243953439544395543956439574395843959439604396143962439634396443965439664396743968439694397043971439724397343974439754397643977439784397943980439814398243983439844398543986439874398843989439904399143992439934399443995439964399743998439994400044001440024400344004440054400644007440084400944010440114401244013440144401544016440174401844019440204402144022440234402444025440264402744028440294403044031440324403344034440354403644037440384403944040440414404244043440444404544046440474404844049440504405144052440534405444055440564405744058440594406044061440624406344064440654406644067440684406944070440714407244073440744407544076440774407844079440804408144082440834408444085440864408744088440894409044091440924409344094440954409644097440984409944100441014410244103441044410544106441074410844109441104411144112441134411444115441164411744118441194412044121441224412344124441254412644127441284412944130441314413244133441344413544136441374413844139441404414144142441434414444145441464414744148441494415044151441524415344154441554415644157441584415944160441614416244163441644416544166441674416844169441704417144172441734417444175441764417744178441794418044181441824418344184441854418644187441884418944190441914419244193441944419544196441974419844199442004420144202442034420444205442064420744208442094421044211442124421344214442154421644217442184421944220442214422244223442244422544226442274422844229442304423144232442334423444235442364423744238442394424044241442424424344244442454424644247442484424944250442514425244253442544425544256442574425844259442604426144262442634426444265442664426744268442694427044271442724427344274442754427644277442784427944280442814428244283442844428544286442874428844289442904429144292442934429444295442964429744298442994430044301443024430344304443054430644307443084430944310443114431244313443144431544316443174431844319443204432144322443234432444325443264432744328443294433044331443324433344334443354433644337443384433944340443414434244343443444434544346443474434844349443504435144352443534435444355443564435744358443594436044361443624436344364443654436644367443684436944370443714437244373443744437544376443774437844379443804438144382443834438444385443864438744388443894439044391443924439344394443954439644397443984439944400444014440244403444044440544406444074440844409444104441144412444134441444415444164441744418444194442044421444224442344424444254442644427444284442944430444314443244433444344443544436444374443844439444404444144442444434444444445444464444744448444494445044451444524445344454444554445644457444584445944460444614446244463444644446544466444674446844469444704447144472444734447444475444764447744478444794448044481444824448344484444854448644487444884448944490444914449244493444944449544496444974449844499445004450144502445034450444505445064450744508445094451044511445124451344514445154451644517445184451944520445214452244523445244452544526445274452844529445304453144532445334453444535445364453744538445394454044541445424454344544445454454644547445484454944550445514455244553445544455544556445574455844559445604456144562445634456444565445664456744568445694457044571445724457344574445754457644577445784457944580445814458244583445844458544586445874458844589445904459144592445934459444595445964459744598445994460044601446024460344604446054460644607446084460944610446114461244613446144461544616446174461844619446204462144622446234462444625446264462744628446294463044631446324463344634446354463644637446384463944640446414464244643446444464544646446474464844649446504465144652446534465444655446564465744658446594466044661446624466344664446654466644667446684466944670446714467244673446744467544676446774467844679446804468144682446834468444685446864468744688446894469044691446924469344694446954469644697446984469944700447014470244703447044470544706447074470844709447104471144712447134471444715447164471744718447194472044721447224472344724447254472644727447284472944730447314473244733447344473544736447374473844739447404474144742447434474444745447464474744748447494475044751447524475344754447554475644757447584475944760447614476244763447644476544766447674476844769447704477144772447734477444775447764477744778447794478044781447824478344784447854478644787447884478944790447914479244793447944479544796447974479844799448004480144802448034480444805448064480744808448094481044811448124481344814448154481644817448184481944820448214482244823448244482544826448274482844829448304483144832448334483444835448364483744838448394484044841448424484344844448454484644847448484484944850448514485244853448544485544856448574485844859448604486144862448634486444865448664486744868448694487044871448724487344874448754487644877448784487944880448814488244883448844488544886448874488844889448904489144892448934489444895448964489744898448994490044901449024490344904449054490644907449084490944910449114491244913449144491544916449174491844919449204492144922449234492444925449264492744928449294493044931449324493344934449354493644937449384493944940449414494244943449444494544946449474494844949449504495144952449534495444955449564495744958449594496044961449624496344964449654496644967449684496944970449714497244973449744497544976449774497844979449804498144982449834498444985449864498744988449894499044991449924499344994449954499644997449984499945000450014500245003450044500545006450074500845009450104501145012450134501445015450164501745018450194502045021450224502345024450254502645027450284502945030450314503245033450344503545036450374503845039450404504145042450434504445045450464504745048450494505045051450524505345054450554505645057450584505945060450614506245063450644506545066450674506845069450704507145072450734507445075450764507745078450794508045081450824508345084450854508645087450884508945090450914509245093450944509545096450974509845099451004510145102451034510445105451064510745108451094511045111451124511345114451154511645117451184511945120451214512245123451244512545126451274512845129451304513145132451334513445135451364513745138451394514045141451424514345144451454514645147451484514945150451514515245153451544515545156451574515845159451604516145162451634516445165451664516745168451694517045171451724517345174451754517645177451784517945180451814518245183451844518545186451874518845189451904519145192451934519445195451964519745198451994520045201452024520345204452054520645207452084520945210452114521245213452144521545216452174521845219452204522145222452234522445225452264522745228452294523045231452324523345234452354523645237452384523945240452414524245243452444524545246452474524845249452504525145252452534525445255452564525745258452594526045261452624526345264452654526645267452684526945270452714527245273452744527545276452774527845279452804528145282452834528445285452864528745288452894529045291452924529345294452954529645297452984529945300453014530245303453044530545306453074530845309453104531145312453134531445315453164531745318453194532045321453224532345324453254532645327453284532945330453314533245333453344533545336453374533845339453404534145342453434534445345453464534745348453494535045351453524535345354453554535645357453584535945360453614536245363453644536545366453674536845369453704537145372453734537445375453764537745378453794538045381453824538345384453854538645387453884538945390453914539245393453944539545396453974539845399454004540145402454034540445405454064540745408454094541045411454124541345414454154541645417454184541945420454214542245423454244542545426454274542845429454304543145432454334543445435454364543745438454394544045441454424544345444454454544645447454484544945450454514545245453454544545545456454574545845459454604546145462454634546445465454664546745468454694547045471454724547345474454754547645477454784547945480454814548245483454844548545486454874548845489454904549145492454934549445495454964549745498454994550045501455024550345504455054550645507455084550945510455114551245513455144551545516455174551845519455204552145522455234552445525455264552745528455294553045531455324553345534455354553645537455384553945540455414554245543455444554545546455474554845549455504555145552455534555445555455564555745558455594556045561455624556345564455654556645567455684556945570455714557245573455744557545576455774557845579455804558145582455834558445585455864558745588455894559045591455924559345594455954559645597455984559945600456014560245603456044560545606456074560845609456104561145612456134561445615456164561745618456194562045621456224562345624456254562645627456284562945630456314563245633456344563545636456374563845639456404564145642456434564445645456464564745648456494565045651456524565345654456554565645657456584565945660456614566245663456644566545666456674566845669456704567145672456734567445675456764567745678456794568045681456824568345684456854568645687456884568945690456914569245693456944569545696456974569845699457004570145702457034570445705457064570745708457094571045711457124571345714457154571645717457184571945720457214572245723457244572545726457274572845729457304573145732457334573445735457364573745738457394574045741457424574345744457454574645747457484574945750457514575245753457544575545756457574575845759457604576145762457634576445765457664576745768457694577045771457724577345774457754577645777457784577945780457814578245783457844578545786457874578845789457904579145792457934579445795457964579745798457994580045801458024580345804458054580645807458084580945810458114581245813458144581545816458174581845819458204582145822458234582445825458264582745828458294583045831458324583345834458354583645837458384583945840458414584245843458444584545846458474584845849458504585145852458534585445855458564585745858458594586045861458624586345864458654586645867458684586945870458714587245873458744587545876458774587845879458804588145882458834588445885458864588745888458894589045891458924589345894458954589645897458984589945900459014590245903459044590545906459074590845909459104591145912459134591445915459164591745918459194592045921459224592345924459254592645927459284592945930459314593245933459344593545936459374593845939459404594145942459434594445945459464594745948459494595045951459524595345954459554595645957459584595945960459614596245963459644596545966459674596845969459704597145972459734597445975459764597745978459794598045981459824598345984459854598645987459884598945990459914599245993459944599545996459974599845999460004600146002460034600446005460064600746008460094601046011460124601346014460154601646017460184601946020460214602246023460244602546026460274602846029460304603146032460334603446035460364603746038460394604046041460424604346044460454604646047460484604946050460514605246053460544605546056460574605846059460604606146062460634606446065460664606746068460694607046071460724607346074460754607646077460784607946080460814608246083460844608546086460874608846089460904609146092460934609446095460964609746098460994610046101461024610346104461054610646107461084610946110461114611246113461144611546116461174611846119461204612146122461234612446125461264612746128461294613046131461324613346134461354613646137461384613946140461414614246143461444614546146461474614846149461504615146152461534615446155461564615746158461594616046161461624616346164461654616646167461684616946170461714617246173461744617546176461774617846179461804618146182461834618446185461864618746188461894619046191461924619346194461954619646197461984619946200462014620246203462044620546206462074620846209462104621146212462134621446215462164621746218462194622046221462224622346224462254622646227462284622946230462314623246233462344623546236462374623846239462404624146242462434624446245462464624746248462494625046251462524625346254462554625646257462584625946260462614626246263462644626546266462674626846269462704627146272462734627446275462764627746278462794628046281462824628346284462854628646287462884628946290462914629246293462944629546296462974629846299463004630146302463034630446305463064630746308463094631046311463124631346314463154631646317463184631946320463214632246323463244632546326463274632846329463304633146332463334633446335463364633746338463394634046341463424634346344463454634646347463484634946350463514635246353463544635546356463574635846359463604636146362463634636446365463664636746368463694637046371463724637346374463754637646377463784637946380463814638246383463844638546386463874638846389463904639146392463934639446395463964639746398463994640046401464024640346404464054640646407464084640946410464114641246413464144641546416464174641846419464204642146422464234642446425464264642746428464294643046431464324643346434464354643646437464384643946440464414644246443464444644546446464474644846449464504645146452464534645446455464564645746458464594646046461464624646346464464654646646467464684646946470464714647246473464744647546476464774647846479464804648146482464834648446485464864648746488464894649046491464924649346494464954649646497464984649946500465014650246503465044650546506465074650846509465104651146512465134651446515465164651746518465194652046521465224652346524465254652646527465284652946530465314653246533465344653546536465374653846539465404654146542465434654446545465464654746548465494655046551465524655346554465554655646557465584655946560465614656246563465644656546566465674656846569465704657146572465734657446575465764657746578465794658046581465824658346584465854658646587465884658946590465914659246593465944659546596465974659846599466004660146602466034660446605466064660746608466094661046611466124661346614466154661646617466184661946620466214662246623466244662546626466274662846629466304663146632466334663446635466364663746638466394664046641466424664346644466454664646647466484664946650466514665246653466544665546656466574665846659466604666146662466634666446665466664666746668466694667046671466724667346674466754667646677466784667946680466814668246683466844668546686466874668846689466904669146692466934669446695466964669746698466994670046701467024670346704467054670646707467084670946710467114671246713467144671546716467174671846719467204672146722467234672446725467264672746728467294673046731467324673346734467354673646737467384673946740467414674246743467444674546746467474674846749467504675146752467534675446755467564675746758467594676046761467624676346764467654676646767467684676946770467714677246773467744677546776467774677846779467804678146782467834678446785467864678746788467894679046791467924679346794467954679646797467984679946800468014680246803468044680546806468074680846809468104681146812468134681446815468164681746818468194682046821468224682346824468254682646827468284682946830468314683246833468344683546836468374683846839468404684146842468434684446845468464684746848468494685046851468524685346854468554685646857468584685946860468614686246863468644686546866468674686846869468704687146872468734687446875468764687746878468794688046881468824688346884468854688646887468884688946890468914689246893468944689546896468974689846899469004690146902469034690446905469064690746908469094691046911469124691346914469154691646917469184691946920469214692246923469244692546926469274692846929469304693146932469334693446935469364693746938469394694046941469424694346944469454694646947469484694946950469514695246953469544695546956469574695846959469604696146962469634696446965469664696746968469694697046971469724697346974469754697646977469784697946980469814698246983469844698546986469874698846989469904699146992469934699446995469964699746998469994700047001470024700347004470054700647007470084700947010470114701247013470144701547016470174701847019470204702147022470234702447025470264702747028470294703047031470324703347034470354703647037470384703947040470414704247043470444704547046470474704847049470504705147052470534705447055470564705747058470594706047061470624706347064470654706647067470684706947070470714707247073470744707547076470774707847079470804708147082470834708447085470864708747088470894709047091470924709347094470954709647097470984709947100471014710247103471044710547106471074710847109471104711147112471134711447115471164711747118471194712047121471224712347124471254712647127471284712947130471314713247133471344713547136471374713847139471404714147142471434714447145471464714747148471494715047151471524715347154471554715647157471584715947160471614716247163471644716547166471674716847169471704717147172471734717447175471764717747178471794718047181471824718347184471854718647187471884718947190471914719247193471944719547196471974719847199472004720147202472034720447205472064720747208472094721047211472124721347214472154721647217472184721947220472214722247223472244722547226472274722847229472304723147232472334723447235472364723747238472394724047241472424724347244472454724647247472484724947250472514725247253472544725547256472574725847259472604726147262472634726447265472664726747268472694727047271472724727347274472754727647277472784727947280472814728247283472844728547286472874728847289472904729147292472934729447295472964729747298472994730047301473024730347304473054730647307473084730947310473114731247313473144731547316473174731847319473204732147322473234732447325473264732747328473294733047331473324733347334473354733647337473384733947340473414734247343473444734547346473474734847349473504735147352473534735447355473564735747358473594736047361473624736347364473654736647367473684736947370473714737247373473744737547376473774737847379473804738147382473834738447385473864738747388473894739047391473924739347394473954739647397473984739947400474014740247403474044740547406474074740847409474104741147412474134741447415474164741747418474194742047421474224742347424474254742647427474284742947430474314743247433474344743547436474374743847439474404744147442474434744447445474464744747448474494745047451474524745347454474554745647457474584745947460474614746247463474644746547466474674746847469474704747147472474734747447475474764747747478474794748047481474824748347484474854748647487474884748947490474914749247493474944749547496474974749847499475004750147502475034750447505475064750747508475094751047511475124751347514475154751647517475184751947520475214752247523475244752547526475274752847529475304753147532475334753447535475364753747538475394754047541475424754347544475454754647547475484754947550475514755247553475544755547556475574755847559475604756147562475634756447565475664756747568475694757047571475724757347574475754757647577475784757947580475814758247583475844758547586475874758847589475904759147592475934759447595475964759747598475994760047601476024760347604476054760647607476084760947610476114761247613476144761547616476174761847619476204762147622476234762447625476264762747628476294763047631476324763347634476354763647637476384763947640476414764247643476444764547646476474764847649476504765147652476534765447655476564765747658476594766047661476624766347664476654766647667476684766947670476714767247673476744767547676476774767847679476804768147682476834768447685476864768747688476894769047691476924769347694476954769647697476984769947700477014770247703477044770547706477074770847709477104771147712477134771447715477164771747718477194772047721477224772347724477254772647727477284772947730477314773247733477344773547736477374773847739477404774147742477434774447745477464774747748477494775047751477524775347754477554775647757477584775947760477614776247763477644776547766477674776847769477704777147772477734777447775477764777747778477794778047781477824778347784477854778647787477884778947790477914779247793477944779547796477974779847799478004780147802478034780447805478064780747808478094781047811478124781347814478154781647817478184781947820478214782247823478244782547826478274782847829478304783147832478334783447835478364783747838478394784047841478424784347844478454784647847478484784947850478514785247853478544785547856478574785847859478604786147862478634786447865478664786747868478694787047871478724787347874478754787647877478784787947880478814788247883478844788547886478874788847889478904789147892478934789447895478964789747898478994790047901479024790347904479054790647907479084790947910479114791247913479144791547916479174791847919479204792147922479234792447925479264792747928479294793047931479324793347934479354793647937479384793947940479414794247943479444794547946479474794847949479504795147952479534795447955479564795747958479594796047961479624796347964479654796647967479684796947970479714797247973479744797547976479774797847979479804798147982479834798447985479864798747988479894799047991479924799347994479954799647997479984799948000480014800248003480044800548006480074800848009480104801148012480134801448015480164801748018480194802048021480224802348024480254802648027480284802948030480314803248033480344803548036480374803848039480404804148042480434804448045480464804748048480494805048051480524805348054480554805648057480584805948060480614806248063480644806548066480674806848069480704807148072480734807448075480764807748078480794808048081480824808348084480854808648087480884808948090480914809248093480944809548096480974809848099481004810148102481034810448105481064810748108481094811048111481124811348114481154811648117481184811948120481214812248123481244812548126481274812848129481304813148132481334813448135481364813748138481394814048141481424814348144481454814648147481484814948150481514815248153481544815548156481574815848159481604816148162481634816448165481664816748168481694817048171481724817348174481754817648177481784817948180481814818248183481844818548186481874818848189481904819148192481934819448195481964819748198481994820048201482024820348204482054820648207482084820948210482114821248213482144821548216482174821848219482204822148222482234822448225482264822748228482294823048231482324823348234482354823648237482384823948240482414824248243482444824548246482474824848249482504825148252482534825448255482564825748258482594826048261482624826348264482654826648267482684826948270482714827248273482744827548276482774827848279482804828148282482834828448285482864828748288482894829048291482924829348294482954829648297482984829948300483014830248303483044830548306483074830848309483104831148312483134831448315483164831748318483194832048321483224832348324483254832648327483284832948330483314833248333483344833548336483374833848339483404834148342483434834448345483464834748348483494835048351483524835348354483554835648357483584835948360483614836248363483644836548366483674836848369483704837148372483734837448375483764837748378483794838048381483824838348384483854838648387483884838948390483914839248393483944839548396483974839848399484004840148402484034840448405484064840748408484094841048411484124841348414484154841648417484184841948420484214842248423484244842548426484274842848429484304843148432484334843448435484364843748438484394844048441484424844348444484454844648447484484844948450484514845248453484544845548456484574845848459484604846148462484634846448465484664846748468484694847048471484724847348474484754847648477484784847948480484814848248483484844848548486484874848848489484904849148492484934849448495484964849748498484994850048501485024850348504485054850648507485084850948510485114851248513485144851548516485174851848519485204852148522485234852448525485264852748528485294853048531485324853348534485354853648537485384853948540485414854248543485444854548546485474854848549485504855148552485534855448555485564855748558485594856048561485624856348564485654856648567485684856948570485714857248573485744857548576485774857848579485804858148582485834858448585485864858748588485894859048591485924859348594485954859648597485984859948600486014860248603486044860548606486074860848609486104861148612486134861448615486164861748618486194862048621486224862348624486254862648627486284862948630486314863248633486344863548636486374863848639486404864148642486434864448645486464864748648486494865048651486524865348654486554865648657486584865948660486614866248663486644866548666486674866848669486704867148672486734867448675486764867748678486794868048681486824868348684486854868648687486884868948690486914869248693486944869548696486974869848699487004870148702487034870448705487064870748708487094871048711487124871348714487154871648717487184871948720487214872248723487244872548726487274872848729487304873148732487334873448735487364873748738487394874048741487424874348744487454874648747487484874948750487514875248753487544875548756487574875848759487604876148762487634876448765487664876748768487694877048771487724877348774487754877648777487784877948780487814878248783487844878548786487874878848789487904879148792487934879448795487964879748798487994880048801488024880348804488054880648807488084880948810488114881248813488144881548816488174881848819488204882148822488234882448825488264882748828488294883048831488324883348834488354883648837488384883948840488414884248843488444884548846488474884848849488504885148852488534885448855488564885748858488594886048861488624886348864488654886648867488684886948870488714887248873488744887548876488774887848879488804888148882488834888448885488864888748888488894889048891488924889348894488954889648897488984889948900489014890248903489044890548906489074890848909489104891148912489134891448915489164891748918489194892048921489224892348924489254892648927489284892948930489314893248933489344893548936489374893848939489404894148942489434894448945489464894748948489494895048951489524895348954489554895648957489584895948960489614896248963489644896548966489674896848969489704897148972489734897448975489764897748978489794898048981489824898348984489854898648987489884898948990489914899248993489944899548996489974899848999490004900149002490034900449005490064900749008490094901049011490124901349014490154901649017490184901949020490214902249023490244902549026490274902849029490304903149032490334903449035490364903749038490394904049041490424904349044490454904649047490484904949050490514905249053490544905549056490574905849059490604906149062490634906449065490664906749068490694907049071490724907349074490754907649077490784907949080490814908249083490844908549086490874908849089490904909149092490934909449095490964909749098490994910049101491024910349104491054910649107491084910949110491114911249113491144911549116491174911849119491204912149122491234912449125491264912749128491294913049131491324913349134491354913649137491384913949140491414914249143491444914549146491474914849149491504915149152491534915449155491564915749158491594916049161491624916349164491654916649167491684916949170491714917249173491744917549176491774917849179491804918149182491834918449185491864918749188491894919049191491924919349194491954919649197491984919949200492014920249203492044920549206492074920849209492104921149212492134921449215492164921749218492194922049221492224922349224492254922649227492284922949230492314923249233492344923549236492374923849239492404924149242492434924449245492464924749248492494925049251492524925349254492554925649257492584925949260492614926249263492644926549266492674926849269492704927149272492734927449275492764927749278492794928049281492824928349284492854928649287492884928949290492914929249293492944929549296492974929849299493004930149302493034930449305493064930749308493094931049311493124931349314493154931649317493184931949320493214932249323493244932549326493274932849329493304933149332493334933449335493364933749338493394934049341493424934349344493454934649347493484934949350493514935249353493544935549356493574935849359493604936149362493634936449365493664936749368493694937049371493724937349374493754937649377493784937949380493814938249383493844938549386493874938849389493904939149392493934939449395493964939749398493994940049401494024940349404494054940649407494084940949410494114941249413494144941549416494174941849419494204942149422494234942449425494264942749428494294943049431494324943349434494354943649437494384943949440494414944249443494444944549446494474944849449494504945149452494534945449455494564945749458494594946049461494624946349464494654946649467494684946949470494714947249473494744947549476494774947849479494804948149482494834948449485494864948749488494894949049491494924949349494494954949649497494984949949500495014950249503495044950549506495074950849509495104951149512495134951449515495164951749518495194952049521495224952349524495254952649527495284952949530495314953249533495344953549536495374953849539495404954149542495434954449545495464954749548495494955049551495524955349554495554955649557495584955949560495614956249563495644956549566495674956849569495704957149572495734957449575495764957749578495794958049581495824958349584495854958649587495884958949590495914959249593495944959549596495974959849599496004960149602496034960449605496064960749608496094961049611496124961349614496154961649617496184961949620496214962249623496244962549626496274962849629496304963149632496334963449635496364963749638496394964049641496424964349644496454964649647496484964949650496514965249653496544965549656496574965849659496604966149662496634966449665496664966749668496694967049671496724967349674496754967649677496784967949680496814968249683496844968549686496874968849689496904969149692496934969449695496964969749698496994970049701497024970349704497054970649707497084970949710497114971249713497144971549716497174971849719497204972149722497234972449725497264972749728497294973049731497324973349734497354973649737497384973949740497414974249743497444974549746497474974849749497504975149752497534975449755497564975749758497594976049761497624976349764497654976649767497684976949770497714977249773497744977549776497774977849779497804978149782497834978449785497864978749788497894979049791497924979349794497954979649797497984979949800498014980249803498044980549806498074980849809498104981149812498134981449815498164981749818498194982049821498224982349824498254982649827498284982949830498314983249833498344983549836498374983849839498404984149842498434984449845498464984749848498494985049851498524985349854498554985649857498584985949860498614986249863498644986549866498674986849869498704987149872498734987449875498764987749878498794988049881498824988349884498854988649887498884988949890498914989249893498944989549896498974989849899499004990149902499034990449905499064990749908499094991049911499124991349914499154991649917499184991949920499214992249923499244992549926499274992849929499304993149932499334993449935499364993749938499394994049941499424994349944499454994649947499484994949950499514995249953499544995549956499574995849959499604996149962499634996449965499664996749968499694997049971499724997349974499754997649977499784997949980499814998249983499844998549986499874998849989499904999149992499934999449995499964999749998499995000050001500025000350004500055000650007500085000950010500115001250013500145001550016500175001850019500205002150022500235002450025500265002750028500295003050031500325003350034500355003650037500385003950040500415004250043500445004550046500475004850049500505005150052500535005450055500565005750058500595006050061500625006350064500655006650067500685006950070500715007250073500745007550076500775007850079500805008150082500835008450085500865008750088500895009050091500925009350094500955009650097500985009950100501015010250103501045010550106501075010850109501105011150112501135011450115501165011750118501195012050121501225012350124501255012650127501285012950130501315013250133501345013550136501375013850139501405014150142501435014450145501465014750148501495015050151501525015350154501555015650157501585015950160501615016250163501645016550166501675016850169501705017150172 |
- From 536821c33d55b5d714910c015008d2cebd7dfef5 Mon Sep 17 00:00:00 2001
- From: Robert Ogden <robertogden@chromium.org>
- Date: Wed, 25 May 2022 11:03:46 -0700
- Subject: [PATCH 8/9] run clang format
- ---
- .../configuration/edgetpu_coral_plugin.cc | 20 +-
- .../edgetpu_coral_plugin_test.cc | 3 +-
- .../src/tensorflow_lite_support/c/common.cc | 2 +-
- .../src/tensorflow_lite_support/c/common.h | 4 +-
- .../tensorflow_lite_support/c/common_utils.cc | 11 +-
- .../tensorflow_lite_support/c/common_utils.h | 3 +-
- .../c/task/audio/audio_classifier.cc | 12 +-
- .../c/task/audio/audio_classifier.h | 12 +-
- .../c/task/audio/core/audio_buffer.h | 4 +-
- .../c/task/processor/classification_result.cc | 2 +-
- .../c/task/text/bert_nl_classifier.cc | 6 +-
- .../c/task/text/bert_nl_classifier.h | 6 +-
- .../c/task/text/bert_question_answerer.cc | 3 +-
- .../c/task/text/bert_question_answerer.h | 3 +-
- .../c/task/text/nl_classifier.cc | 3 +-
- .../c/task/text/nl_classifier.h | 3 +-
- .../c/task/vision/image_classifier.cc | 9 +-
- .../c/task/vision/image_classifier.h | 9 +-
- .../c/task/vision/image_segmenter.cc | 6 +-
- .../c/task/vision/image_segmenter.h | 6 +-
- .../c/task/vision/object_detector.cc | 6 +-
- .../c/task/vision/object_detector.h | 6 +-
- .../test/task/audio/audio_classifier_test.cc | 32 +-
- .../test/task/vision/image_classifier_test.cc | 84 +-
- .../test/task/vision/image_segmenter_test.cc | 62 +-
- .../test/task/vision/object_detector_test.cc | 90 +-
- .../src/tensorflow_lite_support/cc/common.cc | 2 +-
- .../src/tensorflow_lite_support/cc/common.h | 5 +-
- .../cc/port/default/status_macros.h | 2 +-
- .../cc/port/default/statusor_internals.h | 38 +-
- .../cc/port/default/tflite_wrapper.cc | 9 +-
- .../cc/port/default/tflite_wrapper.h | 2 +-
- .../cc/port/integral_types.h | 2 +-
- .../cc/task/audio/audio_classifier.cc | 2 +-
- .../cc/task/audio/audio_embedder.cc | 3 +-
- .../cc/task/audio/audio_embedder.h | 9 +-
- .../cc/task/audio/core/audio_buffer.h | 10 +-
- .../cc/task/audio/utils/audio_utils.cc | 3 +-
- .../cc/task/audio/utils/audio_utils.h | 3 +-
- .../cc/task/audio/utils/wav_io.cc | 19 +-
- .../cc/task/audio/utils/wav_io.h | 6 +-
- .../cc/task/core/base_task_api.h | 2 +-
- .../cc/task/core/classification_head.h | 2 +-
- .../cc/task/core/error_reporter.cc | 8 +-
- .../cc/task/core/external_file_handler.cc | 7 +-
- .../cc/task/core/external_file_handler.h | 3 +-
- .../cc/task/core/label_map_item.cc | 5 +-
- .../cc/task/core/label_map_item.h | 7 +-
- .../cc/task/core/score_calibration.cc | 8 +-
- .../cc/task/core/score_calibration.h | 11 +-
- .../cc/task/core/task_api_factory.h | 8 +-
- .../cc/task/core/task_utils.h | 30 +-
- .../cc/task/core/tflite_engine.cc | 14 +-
- .../cc/task/core/tflite_engine.h | 13 +-
- .../cc/task/processor/audio_preprocessor.cc | 5 +-
- .../processor/classification_postprocessor.cc | 5 +-
- .../task/processor/embedding_postprocessor.h | 10 +-
- .../cc/task/processor/image_preprocessor.cc | 6 +-
- .../cc/task/processor/processor.h | 5 +-
- .../cc/task/processor/regex_preprocessor.cc | 3 +-
- .../cc/task/processor/regex_preprocessor.h | 3 +-
- .../cc/task/processor/search_postprocessor.cc | 40 +-
- .../cc/task/processor/search_postprocessor.h | 37 +-
- .../cc/task/text/bert_clu_annotator.cc | 4 +-
- .../cc/task/text/bert_nl_classifier.cc | 3 +-
- .../cc/task/text/bert_nl_classifier.h | 2 +-
- .../cc/task/text/bert_question_answerer.cc | 32 +-
- .../cc/task/text/bert_question_answerer.h | 7 +-
- .../cc/task/text/clu_lib/bert_utils.cc | 14 +-
- .../cc/task/text/clu_lib/bert_utils.h | 7 +-
- .../cc/task/text/clu_lib/intent_repr.cc | 18 +-
- .../cc/task/text/clu_lib/intent_repr.h | 5 +-
- .../cc/task/text/clu_lib/slot_repr.cc | 32 +-
- .../cc/task/text/clu_lib/slot_repr.h | 9 +-
- .../task/text/clu_lib/slot_tagging_output.cc | 24 +-
- .../task/text/clu_lib/slot_tagging_output.h | 6 +-
- .../cc/task/text/clu_lib/tflite_modules.cc | 41 +-
- .../cc/task/text/clu_lib/tflite_modules.h | 17 +-
- .../cc/task/text/clu_lib/tflite_test_utils.cc | 14 +-
- .../cc/task/text/clu_lib/tflite_test_utils.h | 6 +-
- .../task/text/nlclassifier/nl_classifier.cc | 18 +-
- .../cc/task/text/nlclassifier/nl_classifier.h | 19 +-
- .../text/proto/text_searcher_options.proto | 1 -
- .../cc/task/text/question_answerer.h | 6 +-
- .../cc/task/text/text_embedder.cc | 6 +-
- .../cc/task/text/text_embedder.h | 3 +-
- .../cc/task/text/text_searcher.h | 4 +-
- .../text/universal_sentence_encoder_qa.cc | 14 +-
- .../task/text/universal_sentence_encoder_qa.h | 7 +-
- .../cc/task/text/utils/bert_utils.cc | 2 +-
- .../task/vision/core/base_vision_task_api.h | 9 +-
- .../cc/task/vision/core/classification_head.h | 2 +-
- .../cc/task/vision/core/frame_buffer.h | 47 +-
- .../cc/task/vision/core/label_map_item.cc | 5 +-
- .../cc/task/vision/core/label_map_item.h | 7 +-
- .../cc/task/vision/image_classifier.cc | 14 +-
- .../cc/task/vision/image_classifier.h | 8 +-
- .../cc/task/vision/image_embedder.cc | 17 +-
- .../cc/task/vision/image_embedder.h | 9 +-
- .../cc/task/vision/image_searcher.cc | 7 +-
- .../cc/task/vision/image_searcher.h | 8 +-
- .../cc/task/vision/image_segmenter.cc | 17 +-
- .../cc/task/vision/image_segmenter.h | 8 +-
- .../cc/task/vision/object_detector.cc | 14 +-
- .../cc/task/vision/object_detector.h | 5 +-
- .../vision/proto/image_searcher_options.proto | 2 -
- .../vision/utils/frame_buffer_common_utils.cc | 59 +-
- .../vision/utils/frame_buffer_common_utils.h | 37 +-
- .../task/vision/utils/frame_buffer_utils.cc | 50 +-
- .../cc/task/vision/utils/frame_buffer_utils.h | 40 +-
- .../utils/frame_buffer_utils_interface.h | 11 +-
- .../cc/task/vision/utils/image_utils.cc | 12 +-
- .../cc/task/vision/utils/image_utils.h | 2 +-
- .../vision/utils/libyuv_frame_buffer_utils.cc | 81 +-
- .../vision/utils/libyuv_frame_buffer_utils.h | 9 +-
- .../cc/task/vision/utils/score_calibration.cc | 8 +-
- .../cc/task/vision/utils/score_calibration.h | 11 +-
- .../cc/test/common_test.cc | 2 +-
- .../task/processor/image_preprocessor_test.cc | 13 +-
- .../test/task/text/bert_nl_classifier_test.cc | 36 +-
- .../task/text/bert_question_answerer_test.cc | 7 +-
- .../test/task/text/clu_lib/bert_utils_test.cc | 32 +-
- .../task/text/clu_lib/intent_repr_test.cc | 2 +-
- .../text/nlclassifier/nl_classifier_test.cc | 83 +-
- .../cc/test/task/text/text_embedder_test.cc | 26 +-
- .../cc/test/task/text/text_searcher_test.cc | 18 +-
- .../universal_sentence_encoder_qa_test.cc | 16 +-
- .../test/task/vision/image_classifier_test.cc | 158 +-
- .../test/task/vision/image_embedder_test.cc | 95 +-
- .../test/task/vision/image_searcher_test.cc | 62 +-
- .../test/task/vision/image_segmenter_test.cc | 117 +-
- .../test/task/vision/object_detector_test.cc | 157 +-
- .../cc/test/test_utils.cc | 18 +-
- .../cc/test/test_utils.h | 6 +-
- .../cc/text/tokenizers/bert_tokenizer.cc | 3 +-
- .../cc/text/tokenizers/bert_tokenizer.h | 3 +-
- .../cc/text/tokenizers/bert_tokenizer_jni.cc | 25 +-
- .../cc/text/tokenizers/regex_tokenizer.cc | 4 +-
- .../cc/text/tokenizers/sentencepiece_jni.cc | 20 +-
- .../cc/text/tokenizers/tokenizer_jni_lib.cc | 3 +-
- .../cc/text/tokenizers/tokenizer_jni_lib.h | 3 +-
- .../cc/text/tokenizers/tokenizer_utils.cc | 6 +-
- .../cc/text/tokenizers/tokenizer_utils.h | 1 -
- .../cc/utils/common_utils.cc | 3 +-
- .../cc/utils/common_utils.h | 3 +-
- .../cc/utils/jni_utils.cc | 7 +-
- .../cc/utils/jni_utils.h | 9 +-
- .../codegen/android_java_generator.cc | 37 +-
- .../codegen/android_java_generator.h | 5 +-
- .../codegen/code_generator.cc | 3 +-
- .../codegen/code_generator.h | 3 +-
- .../codegen/code_generator_test.cc | 3 +-
- .../codegen/metadata_helper.h | 2 +-
- .../codegen/python/codegen_lib.cc | 9 +-
- .../tensorflow_lite_support/codegen/utils.cc | 36 +-
- .../custom_ops/kernel/ngrams.cc | 7 +-
- .../custom_ops/kernel/ngrams_op_resolver.cc | 2 +-
- .../custom_ops/kernel/ngrams_test.cc | 9 +-
- .../kernel/ragged/py_tflite_registerer.h | 2 +-
- .../kernel/ragged/ragged_range_tflite.cc | 9 +-
- .../kernel/ragged/ragged_range_tflite_test.cc | 3 +-
- .../ragged/ragged_tensor_to_tensor_tflite.cc | 47 +-
- .../ragged_tensor_to_tensor_tflite_test.cc | 6 +-
- .../kernel/sentencepiece/model_converter.cc | 10 +-
- .../kernel/sentencepiece/model_converter.h | 6 +-
- .../sentencepiece/optimized_decoder_test.cc | 6 +-
- .../kernel/sentencepiece/optimized_encoder.cc | 23 +-
- .../kernel/sentencepiece/optimized_encoder.h | 10 +-
- .../sentencepiece/optimized_encoder_test.cc | 8 +-
- .../sentencepiece/py_tflite_registerer.h | 2 +-
- .../sentencepiece_detokenizer_tflite.cc | 3 +-
- .../sentencepiece_tokenizer_op.cc | 6 +-
- .../sentencepiece_tokenizer_tflite.cc | 7 +-
- .../custom_ops/kernel/whitespace_tokenizer.cc | 13 +-
- .../whitespace_tokenizer_op_resolver.cc | 2 +-
- .../audio/desktop/audio_classifier_demo.cc | 16 +-
- .../audio/desktop/audio_classifier_lib.cc | 11 +-
- .../task/audio/desktop/audio_classifier_lib.h | 3 +-
- .../text/desktop/bert_nl_classifier_demo.cc | 14 +-
- .../desktop/bert_question_answerer_demo.cc | 18 +-
- .../task/text/desktop/nl_classifier_demo.cc | 14 +-
- .../task/text/desktop/text_embedder_demo.cc | 26 +-
- .../task/text/desktop/text_searcher_demo.cc | 30 +-
- .../universal_sentence_encoder_qa_demo.cc | 17 +-
- .../vision/desktop/image_classifier_demo.cc | 34 +-
- .../vision/desktop/image_embedder_demo.cc | 30 +-
- .../vision/desktop/image_searcher_demo.cc | 30 +-
- .../vision/desktop/image_segmenter_demo.cc | 24 +-
- .../vision/desktop/object_detector_demo.cc | 40 +-
- .../ios/sources/TFLCommon.h | 11 +-
- .../ios/sources/TFLCommonUtils.h | 32 +-
- .../ios/sources/TFLCommonUtils.m | 19 +-
- .../task/audio/core/sources/TFLFloatBuffer.h | 18 +-
- .../task/audio/core/sources/TFLFloatBuffer.m | 4 +-
- .../task/audio/core/sources/TFLRingBuffer.h | 32 +-
- .../task/audio/core/sources/TFLRingBuffer.m | 49 +-
- .../core/sources/TFLBaseOptions+Helpers.h | 2 +-
- .../ios/task/core/sources/TFLBaseOptions.h | 32 +-
- .../processor/sources/TFLCategory+Helpers.h | 2 +-
- .../processor/sources/TFLCategory+Helpers.m | 7 +-
- .../ios/task/processor/sources/TFLCategory.h | 22 +-
- .../ios/task/processor/sources/TFLCategory.m | 4 +-
- .../TFLClassificationOptions+Helpers.h | 6 +-
- .../TFLClassificationOptions+Helpers.m | 33 +-
- .../sources/TFLClassificationOptions.h | 9 +-
- .../sources/TFLClassificationResult+Helpers.h | 17 +-
- .../sources/TFLClassificationResult+Helpers.m | 22 +-
- .../sources/TFLClassificationResult.h | 79 +-
- .../sources/TFLClassificationResult.m | 12 +-
- .../sources/TFLDetectionResult+Helpers.h | 11 +-
- .../sources/TFLDetectionResult+Helpers.m | 15 +-
- .../processor/sources/TFLDetectionResult.h | 35 +-
- .../processor/sources/TFLDetectionResult.m | 4 +-
- .../sources/TFLSegmentationResult+Helpers.h | 4 +-
- .../sources/TFLSegmentationResult+Helpers.m | 44 +-
- .../processor/sources/TFLSegmentationResult.h | 65 +-
- .../processor/sources/TFLSegmentationResult.m | 45 +-
- .../Sources/TFLBertNLClassifier.h | 21 +-
- .../nlclassifier/Sources/TFLNLClassifier.h | 47 +-
- .../text/qa/Sources/TFLBertQuestionAnswerer.h | 4 +-
- .../task/vision/sources/TFLImageClassifier.h | 90 +-
- .../task/vision/sources/TFLImageClassifier.m | 58 +-
- .../task/vision/sources/TFLImageSegmenter.h | 62 +-
- .../task/vision/sources/TFLImageSegmenter.m | 49 +-
- .../task/vision/sources/TFLObjectDetector.h | 64 +-
- .../task/vision/sources/TFLObjectDetector.m | 54 +-
- .../vision/utils/sources/GMLImage+Utils.h | 8 +-
- .../vision/utils/sources/GMLImage+Utils.m | 225 +-
- .../test/task/audio/core/TFLRingBufferTests.m | 171 +-
- .../TFLImageClassifierTests.m | 28 +-
- .../image_segmenter/TFLImageSegmenterTests.m | 64 +-
- .../object_detector/TFLObjectDetectorTests.m | 36 +-
- .../tokenizers/Sources/TFLBertTokenizer.h | 6 +-
- .../Sources/TFLSentencepieceTokenizer.h | 2 +-
- .../text/tokenizers/Sources/TFLTokenizer.h | 4 +-
- .../tokenizers/Sources/TFLTokenizerUtil.h | 11 +-
- .../ios/utils/Sources/TFLStringUtil.mm | 11 +-
- .../lite/support/audio/TensorAudio.java | 524 ++---
- .../lite/support/common/FileUtil.java | 301 +--
- .../lite/support/common/Operator.java | 15 +-
- .../lite/support/common/Processor.java | 2 +-
- .../support/common/SequentialProcessor.java | 83 +-
- .../lite/support/common/TensorOperator.java | 6 +-
- .../lite/support/common/TensorProcessor.java | 57 +-
- .../common/internal/SupportPreconditions.java | 302 +--
- .../lite/support/common/ops/CastOp.java | 55 +-
- .../lite/support/common/ops/DequantizeOp.java | 9 +-
- .../lite/support/common/ops/NormalizeOp.java | 245 ++-
- .../lite/support/common/ops/QuantizeOp.java | 9 +-
- .../lite/support/image/BitmapContainer.java | 116 +-
- .../lite/support/image/BoundingBoxUtil.java | 369 ++--
- .../lite/support/image/ColorSpaceType.java | 623 +++---
- .../lite/support/image/ImageContainer.java | 36 +-
- .../lite/support/image/ImageConversions.java | 217 +-
- .../lite/support/image/ImageOperator.java | 41 +-
- .../lite/support/image/ImageProcessor.java | 285 +--
- .../lite/support/image/ImageProperties.java | 91 +-
- .../support/image/MediaImageContainer.java | 112 +-
- .../lite/support/image/MlImageAdapter.java | 160 +-
- .../support/image/TensorBufferContainer.java | 202 +-
- .../lite/support/image/TensorImage.java | 677 +++---
- .../lite/support/image/ops/ResizeOp.java | 105 +-
- .../image/ops/ResizeWithCropOrPadOp.java | 170 +-
- .../lite/support/image/ops/Rot90Op.java | 141 +-
- .../image/ops/TensorOperatorWrapper.java | 78 +-
- .../image/ops/TransformToGrayscaleOp.java | 127 +-
- .../lite/support/label/Category.java | 192 +-
- .../lite/support/label/LabelUtil.java | 77 +-
- .../lite/support/label/TensorLabel.java | 331 +--
- .../lite/support/label/ops/LabelAxisOp.java | 70 +-
- .../lite/support/model/GpuDelegateProxy.java | 71 +-
- .../tensorflow/lite/support/model/Model.java | 467 +++--
- .../support/tensorbuffer/TensorBuffer.java | 899 ++++----
- .../tensorbuffer/TensorBufferFloat.java | 181 +-
- .../tensorbuffer/TensorBufferUint8.java | 188 +-
- .../audio/classifier/AudioClassifier.java | 857 ++++----
- .../audio/classifier/Classifications.java | 28 +-
- .../lite/task/core/BaseOptions.java | 105 +-
- .../lite/task/core/BaseTaskApi.java | 122 +-
- .../lite/task/core/ComputeSettings.java | 48 +-
- .../lite/task/core/TaskJniUtils.java | 275 ++-
- .../core/annotations/UsedByReflection.java | 2 +-
- .../core/vision/ImageProcessingOptions.java | 125 +-
- .../lite/task/processor/NearestNeighbor.java | 53 +-
- .../lite/task/processor/SearcherOptions.java | 114 +-
- .../text/nlclassifier/BertNLClassifier.java | 391 ++--
- .../task/text/nlclassifier/NLClassifier.java | 568 ++---
- .../task/text/qa/BertQuestionAnswerer.java | 394 ++--
- .../lite/task/text/qa/QaAnswer.java | 60 +-
- .../lite/task/text/qa/QuestionAnswerer.java | 19 +-
- .../lite/task/text/searcher/TextSearcher.java | 375 ++--
- .../vision/classifier/Classifications.java | 25 +-
- .../vision/classifier/ImageClassifier.java | 882 ++++----
- .../task/vision/core/BaseVisionTaskApi.java | 349 ++--
- .../lite/task/vision/detector/Detection.java | 26 +-
- .../task/vision/detector/ObjectDetector.java | 873 ++++----
- .../task/vision/searcher/ImageSearcher.java | 605 +++---
- .../task/vision/segmenter/ColoredLabel.java | 112 +-
- .../task/vision/segmenter/ImageSegmenter.java | 752 ++++---
- .../task/vision/segmenter/OutputType.java | 202 +-
- .../task/vision/segmenter/Segmentation.java | 106 +-
- .../lite/support/audio/TensorAudioTest.java | 505 ++---
- .../lite/support/common/FileUtilTest.java | 129 +-
- .../support/common/TensorProcessorTest.java | 91 +-
- .../lite/support/common/ops/CastOpTest.java | 91 +-
- .../support/common/ops/DequantizeOpTest.java | 23 +-
- .../support/common/ops/NormalizeOpTest.java | 217 +-
- .../support/common/ops/QuantizeOpTest.java | 21 +-
- .../support/image/BoundingBoxUtilTest.java | 343 ++--
- .../image/ColorSpaceTypeInstrumentedTest.java | 37 +-
- .../support/image/ColorSpaceTypeTest.java | 703 +++----
- .../ImageConversionsInstrumentedTest.java | 338 +--
- .../support/image/ImageConversionsTest.java | 164 +-
- .../image/ImageProcessorInstrumentedTest.java | 221 +-
- .../support/image/ImageProcessorTest.java | 209 +-
- .../support/image/MlImageAdapterTest.java | 259 +--
- .../image/TensorImageInstrumentedTest.java | 208 +-
- .../lite/support/image/TensorImageTest.java | 1391 ++++++-------
- .../lite/support/image/TestImageCreator.java | 183 +-
- .../image/ops/ResizeOpInstrumentedTest.java | 103 +-
- ...ResizeWithCropOrPadOpInstrumentedTest.java | 239 ++-
- .../image/ops/Rot90OpInstrumentedTest.java | 122 +-
- ...ransformToGrayScaleOpInstrumentedTest.java | 104 +-
- .../lite/support/label/CategoryTest.java | 204 +-
- .../lite/support/label/LabelUtilTest.java | 47 +-
- .../lite/support/label/TensorLabelTest.java | 327 +--
- .../support/label/ops/LabelAxisOpTest.java | 160 +-
- .../GpuDelegateProxyInstrumentedTest.java | 18 +-
- .../support/model/GpuDelegateProxyTest.java | 11 +-
- .../lite/support/model/ModelTest.java | 244 +--
- .../tensorbuffer/TensorBufferFloatTest.java | 82 +-
- .../tensorbuffer/TensorBufferTest.java | 1707 +++++++--------
- .../tensorbuffer/TensorBufferUint8Test.java | 82 +-
- .../audio/classifier/audio_classifier_jni.cc | 42 +-
- .../src/native/task/core/task_jni_utils.cc | 5 +-
- .../bert/bert_nl_classifier_jni.cc | 23 +-
- .../text/nlclassifier/nl_classifier_jni.cc | 21 +-
- .../text/qa/bert_question_answerer_jni.cc | 24 +-
- .../task/text/searcher/text_searcher_jni.cc | 36 +-
- .../vision/classifier/image_classifier_jni.cc | 27 +-
- .../vision/core/base_vision_task_api_jni.cc | 40 +-
- .../vision/detector/object_detector_jni.cc | 27 +-
- .../java/src/native/task/vision/jni_utils.cc | 30 +-
- .../java/src/native/task/vision/jni_utils.h | 28 +-
- .../vision/searcher/image_searcher_jni.cc | 36 +-
- .../vision/segmenter/image_segmenter_jni.cc | 32 +-
- .../metadata/cc/metadata_extractor.cc | 20 +-
- .../metadata/cc/metadata_extractor.h | 4 +-
- .../metadata/cc/metadata_populator.cc | 2 +-
- .../metadata/cc/metadata_populator.h | 7 +-
- .../metadata/cc/metadata_version.cc | 33 +-
- .../cc/utils/zip_readonly_mem_file.cc | 13 +-
- .../metadata/cc/utils/zip_readonly_mem_file.h | 4 +-
- .../cc/utils/zip_writable_mem_file.cc | 17 +-
- .../metadata/cc/utils/zip_writable_mem_file.h | 4 +-
- .../flatbuffers_lib/flatbuffers_lib.cc | 2 +-
- .../support/metadata/BoundedInputStream.java | 138 +-
- .../support/metadata/ByteBufferChannel.java | 188 +-
- .../support/metadata/MetadataExtractor.java | 622 +++---
- .../lite/support/metadata/MetadataParser.java | 12 +-
- .../lite/support/metadata/ModelInfo.java | 448 ++--
- .../support/metadata/ModelMetadataInfo.java | 243 ++-
- .../lite/support/metadata/Preconditions.java | 306 +--
- .../metadata/SeekableByteChannelCompat.java | 140 +-
- .../lite/support/metadata/ZipFile.java | 686 +++----
- .../metadata/BoundedInputStreamTest.java | 429 ++--
- .../metadata/ByteBufferChannelTest.java | 480 +++--
- .../metadata/MetadataExtractorTest.java | 1828 ++++++++---------
- .../support/metadata/MetadataParserTest.java | 18 +-
- .../lite/support/metadata/ZipFileTest.java | 206 +-
- .../odml/ios/image/apis/GMLImage.h | 47 +-
- .../android/odml/image/BitmapExtractor.java | 43 +-
- .../odml/image/BitmapImageContainer.java | 70 +-
- .../odml/image/BitmapMlImageBuilder.java | 137 +-
- .../odml/image/ByteBufferExtractor.java | 421 ++--
- .../odml/image/ByteBufferImageContainer.java | 68 +-
- .../odml/image/ByteBufferMlImageBuilder.java | 135 +-
- .../android/odml/image/ImageContainer.java | 12 +-
- .../android/odml/image/ImageProperties.java | 92 +-
- .../odml/image/MediaImageContainer.java | 81 +-
- .../odml/image/MediaImageExtractor.java | 42 +-
- .../odml/image/MediaMlImageBuilder.java | 105 +-
- .../google/android/odml/image/MlImage.java | 423 ++--
- .../odml/image/BitmapExtractorTest.java | 46 +-
- .../odml/image/BitmapMlImageBuilderTest.java | 116 +-
- .../odml/image/ByteBufferExtractorTest.java | 264 ++-
- .../image/ByteBufferMlImageBuilderTest.java | 93 +-
- .../odml/image/MediaImageExtractorTest.java | 48 +-
- .../odml/image/MediaMlImageBuilderTest.java | 109 +-
- .../android/odml/image/TestImageCreator.java | 211 +-
- .../core/pybinds/_pywrap_audio_buffer.cc | 17 +-
- .../audio/pybinds/_pywrap_audio_classifier.cc | 1 -
- .../audio/pybinds/_pywrap_audio_embedder.cc | 22 +-
- .../task/vision/core/pybinds/image_utils.cc | 4 +-
- .../pybinds/_pywrap_image_classifier.cc | 16 +-
- .../vision/pybinds/_pywrap_image_segmenter.cc | 12 +-
- .../vision/pybinds/_pywrap_object_detector.cc | 13 +-
- .../scann_ondevice/cc/core/index_table_sum.h | 41 +-
- .../scann_ondevice/cc/core/indexer.cc | 24 +-
- .../scann_ondevice/cc/core/indexer.h | 2 +-
- .../scann_ondevice/cc/core/indexer_test.cc | 6 +-
- .../scann_ondevice/cc/core/partitioner.cc | 8 +-
- .../scann_ondevice/cc/core/partitioner.h | 5 +-
- .../scann_ondevice/cc/core/searcher.h | 29 +-
- .../scann_ondevice/cc/core/searcher_test.cc | 9 +-
- .../cc/core/top_n_amortized_constant.h | 12 +-
- .../scann_ondevice/cc/index.cc | 23 +-
- .../scann_ondevice/cc/index.h | 13 +-
- .../scann_ondevice/cc/index_builder.cc | 24 +-
- .../scann_ondevice/cc/index_builder.h | 14 +-
- .../cc/mem_random_access_file.cc | 7 +-
- .../cc/mem_random_access_file.h | 8 +-
- .../scann_ondevice/cc/mem_writable_file.h | 8 +-
- .../cc/python/index_builder_py_wrapper.cc | 6 +-
- .../cc/test/index_builder_test.cc | 143 +-
- .../scann_ondevice/cc/test/index_test.cc | 33 +-
- .../cc/test/mem_writable_file_test.cc | 2 +-
- .../leveldb_testing_utils_py_wrapper.cc | 14 +-
- .../src/third_party/fft2d/fft.h | 12 +-
- .../src/third_party/fft2d/fft2d.h | 12 +-
- 420 files changed, 19248 insertions(+), 18509 deletions(-)
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc
- index 9f27f3baae82f..6a16d12856258 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc
- @@ -17,12 +17,12 @@ limitations under the License.
-
- #include <glog/logging.h>
- #include "absl/container/node_hash_map.h" // from @com_google_absl
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/strings/match.h" // from @com_google_absl
- -#include "absl/strings/numbers.h" // from @com_google_absl
- -#include "tflite/public/edgetpu_c.h"
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/strings/match.h" // from @com_google_absl
- +#include "absl/strings/numbers.h" // from @com_google_absl
- #include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
- #include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h"
- +#include "tflite/public/edgetpu_c.h"
-
- namespace tflite {
- namespace delegates {
- @@ -50,12 +50,16 @@ inline std::string ConvertBool(bool from_bool) {
- return from_bool ? "True" : "False";
- }
-
- -bool MatchDevice(const std::string& device, const std::string& type,
- +bool MatchDevice(const std::string& device,
- + const std::string& type,
- int* index) {
- const auto prefix(type + ":");
- - if (!absl::StartsWith(device, prefix)) return false;
- - if (!absl::SimpleAtoi(device.substr(prefix.size()), index)) return false;
- - if (*index < 0) return false;
- + if (!absl::StartsWith(device, prefix))
- + return false;
- + if (!absl::SimpleAtoi(device.substr(prefix.size()), index))
- + return false;
- + if (*index < 0)
- + return false;
- return true;
- }
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc
- index a02635b9f3578..6ac4e5c734567 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc
- @@ -43,7 +43,8 @@ using ::tflite::task::vision::ImageDataFree;
-
- using EdgeTpuCoralPluginTest = testing::TestWithParam<std::string>;
-
- -INSTANTIATE_TEST_SUITE_P(CoralPluginTests, EdgeTpuCoralPluginTest,
- +INSTANTIATE_TEST_SUITE_P(CoralPluginTests,
- + EdgeTpuCoralPluginTest,
- testing::Values(kRegularModelFilePath,
- kEdgeTpuModelFilePath));
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc
- index 2a182bbd6535a..f0974ed26b826 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc
- @@ -17,7 +17,7 @@ limitations under the License.
-
- #include <cstdlib>
-
- -void TfLiteSupportErrorDelete(TfLiteSupportError *error) {
- +void TfLiteSupportErrorDelete(TfLiteSupportError* error) {
- // `strdup` obtains memory using `malloc` and the memory needs to be
- // released using `free`.
- free(error->message);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common.h b/third_party/tflite_support/src/tensorflow_lite_support/c/common.h
- index 1e21f1dcb31dc..3ced64226987f 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/common.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common.h
- @@ -190,10 +190,10 @@ typedef struct TfLiteSupportError {
- // Holds the error code.
- enum TfLiteSupportErrorCode code;
- // Detailed description of the error.
- - char *message;
- + char* message;
- } TfLiteSupportError;
-
- -void TfLiteSupportErrorDelete(TfLiteSupportError *error);
- +void TfLiteSupportErrorDelete(TfLiteSupportError* error);
-
- #ifdef __cplusplus
- } // extern "C"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc
- index 39287377c4b36..39afb9c8cbdf3 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc
- @@ -18,15 +18,17 @@ limitations under the License.
- #include <string>
-
- #include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/cord.h" // from @com_google_absl
- +#include "absl/strings/cord.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/common.h"
-
- namespace tflite {
- namespace support {
-
- void CreateTfLiteSupportError(enum TfLiteSupportErrorCode code,
- - const char* message, TfLiteSupportError** error) {
- - if (error == nullptr) return;
- + const char* message,
- + TfLiteSupportError** error) {
- + if (error == nullptr)
- + return;
-
- *error = new TfLiteSupportError;
- (*error)->code = code;
- @@ -35,7 +37,8 @@ void CreateTfLiteSupportError(enum TfLiteSupportErrorCode code,
-
- void CreateTfLiteSupportErrorWithStatus(const absl::Status& status,
- TfLiteSupportError** error) {
- - if (status.ok() || error == nullptr) return;
- + if (status.ok() || error == nullptr)
- + return;
-
- // Payload of absl::Status created by the tflite task library stores an
- // appropriate value of the enum TfLiteSupportStatus. The integer value
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h
- index 6959029575663..551f64a598970 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h
- @@ -27,7 +27,8 @@ namespace support {
-
- // Creates a TfLiteSupportError with a TfLiteSupportErrorCode and message.
- void CreateTfLiteSupportError(enum TfLiteSupportErrorCode code,
- - const char* message, TfLiteSupportError** error);
- + const char* message,
- + TfLiteSupportError** error);
-
- // Creates a TfLiteSupportError from absl::Status and passes it back as a
- // parameter which is a pointer to the error pointer.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.cc
- index 89fba26b9b72f..3f1781a0a7db8 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.cc
- @@ -109,7 +109,8 @@ TfLiteAudioClassifierOptions TfLiteAudioClassifierOptionsCreate(void) {
- }
-
- TfLiteAudioClassifier* TfLiteAudioClassifierFromOptions(
- - const TfLiteAudioClassifierOptions* options, TfLiteSupportError** error) {
- + const TfLiteAudioClassifierOptions* options,
- + TfLiteSupportError** error) {
- StatusOr<AudioClassifierOptionsCpp> cpp_option_status =
- CreateAudioClassifierCppOptionsFromCOptions(options);
-
- @@ -181,7 +182,8 @@ TfLiteClassificationResult* GetClassificationResultCStruct(
-
- TfLiteClassificationResult* TfLiteAudioClassifierClassify(
- const TfLiteAudioClassifier* classifier,
- - const TfLiteAudioBuffer* audio_buffer, TfLiteSupportError** error) {
- + const TfLiteAudioBuffer* audio_buffer,
- + TfLiteSupportError** error) {
- if (classifier == nullptr) {
- tflite::support::CreateTfLiteSupportError(
- kInvalidArgumentError, "Expected non null audio classifier.", error);
- @@ -211,7 +213,8 @@ TfLiteClassificationResult* TfLiteAudioClassifierClassify(
- }
-
- int TfLiteAudioClassifierGetRequiredInputBufferSize(
- - TfLiteAudioClassifier* classifier, TfLiteSupportError** error) {
- + TfLiteAudioClassifier* classifier,
- + TfLiteSupportError** error) {
- if (classifier == nullptr) {
- tflite::support::CreateTfLiteSupportError(
- kInvalidArgumentError, "Expected non null audio classifier.", error);
- @@ -226,7 +229,8 @@ void TfLiteAudioClassifierDelete(TfLiteAudioClassifier* classifier) {
- }
-
- TfLiteAudioFormat* TfLiteAudioClassifierGetRequiredAudioFormat(
- - TfLiteAudioClassifier* classifier, TfLiteSupportError** error) {
- + TfLiteAudioClassifier* classifier,
- + TfLiteSupportError** error) {
- if (classifier == nullptr) {
- tflite::support::CreateTfLiteSupportError(
- kInvalidArgumentError, "Expected non null audio classifier.", error);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.h
- index e83295963378c..6af9b27944744 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.h
- @@ -157,7 +157,8 @@ TfLiteAudioClassifierOptions TfLiteAudioClassifierOptionsCreate(void);
- // TfLiteSupportErrorDelete(error)
- //
- TfLiteAudioClassifier* TfLiteAudioClassifierFromOptions(
- - const TfLiteAudioClassifierOptions* options, TfLiteSupportError** error);
- + const TfLiteAudioClassifierOptions* options,
- + TfLiteSupportError** error);
-
- // Invokes the encapsulated TFLite model and classifies the frame_buffer.
- // Returns a pointer to the created classification result in case of success or
- @@ -185,15 +186,18 @@ TfLiteAudioClassifier* TfLiteAudioClassifierFromOptions(
- //
- TfLiteClassificationResult* TfLiteAudioClassifierClassify(
- const TfLiteAudioClassifier* classifier,
- - const TfLiteAudioBuffer* audio_buffer, TfLiteSupportError** error);
- + const TfLiteAudioBuffer* audio_buffer,
- + TfLiteSupportError** error);
-
- // Returns the input buffer size required by the audio classifier.
- int TfLiteAudioClassifierGetRequiredInputBufferSize(
- - TfLiteAudioClassifier* classifier, TfLiteSupportError** error);
- + TfLiteAudioClassifier* classifier,
- + TfLiteSupportError** error);
-
- // Returns the audio format required by the audio classifier.
- TfLiteAudioFormat* TfLiteAudioClassifierGetRequiredAudioFormat(
- - TfLiteAudioClassifier* classifier, TfLiteSupportError** error);
- + TfLiteAudioClassifier* classifier,
- + TfLiteSupportError** error);
-
- // Disposes off the audio classifier.
- void TfLiteAudioClassifierDelete(TfLiteAudioClassifier* classifier);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.h
- index 2ec7571036d29..471f02fdf2132 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.h
- @@ -45,11 +45,11 @@ typedef struct TfLiteAudioBuffer {
- int size;
- } TfLiteAudioBuffer;
-
- -void TfLiteAudioBufferDelete(TfLiteAudioBuffer *buffer);
- +void TfLiteAudioBufferDelete(TfLiteAudioBuffer* buffer);
-
- void TfLiteAudioBufferDeleteData(const TfLiteAudioBuffer audio_buffer);
-
- -void TfLiteAudioFormatDelete(TfLiteAudioFormat *format);
- +void TfLiteAudioFormatDelete(TfLiteAudioFormat* format);
-
- #ifdef __cplusplus
- } // extern "C"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc
- index 646e2c237c2f8..b7d7fab827694 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc
- @@ -27,7 +27,7 @@ void TfLiteClassificationResultDelete(
- for (int head = 0; head < classification_result->size; ++head) {
- TfLiteClassifications classifications =
- classification_result->classifications[head];
- - free(classifications.head_name);
- + free(classifications.head_name);
- for (int rank = 0; rank < classifications.size; ++rank) {
- TfLiteCategoryDelete(&(classifications.categories[rank]));
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc
- index 26888a832fc34..52907f4fe7d35 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc
- @@ -40,7 +40,8 @@ struct TfLiteBertNLClassifier {
- };
-
- TfLiteBertNLClassifier* TfLiteBertNLClassifierCreateFromOptions(
- - const char* model_path, const TfLiteBertNLClassifierOptions* options) {
- + const char* model_path,
- + const TfLiteBertNLClassifierOptions* options) {
- BertNLClassifierOptionsCpp cc_options;
-
- cc_options.mutable_base_options()->mutable_model_file()->set_file_name(
- @@ -64,7 +65,8 @@ TfLiteBertNLClassifier* TfLiteBertNLClassifierCreate(const char* model_path) {
- }
-
- Categories* TfLiteBertNLClassifierClassify(
- - const TfLiteBertNLClassifier* classifier, const char* text) {
- + const TfLiteBertNLClassifier* classifier,
- + const char* text) {
- std::vector<CategoryCpp> results =
-
- classifier->impl->Classify(absl::string_view(text).data());
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h
- index 430f5735c6bd2..94138a291233b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h
- @@ -48,7 +48,8 @@ typedef struct TfLiteBertNLClassifierOptions {
- // Creates TfLiteBertNLClassifier from model path and options, returns nullptr
- // if the file doesn't exist or is not a well formatted TFLite model path.
- TfLiteBertNLClassifier* TfLiteBertNLClassifierCreateFromOptions(
- - const char* model_path, const TfLiteBertNLClassifierOptions* options);
- + const char* model_path,
- + const TfLiteBertNLClassifierOptions* options);
-
- // Creates TfLiteBertNLClassifier from model path and default options, returns
- // nullptr if the file doesn't exist or is not a well formatted TFLite model
- @@ -57,7 +58,8 @@ TfLiteBertNLClassifier* TfLiteBertNLClassifierCreate(const char* model_path);
-
- // Invokes the encapsulated TFLite model and classifies the input text.
- Categories* TfLiteBertNLClassifierClassify(
- - const TfLiteBertNLClassifier* classifier, const char* text);
- + const TfLiteBertNLClassifier* classifier,
- + const char* text);
-
- void TfLiteBertNLClassifierDelete(TfLiteBertNLClassifier* classifier);
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc
- index d0d1639357348..1887d5234d180 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc
- @@ -48,7 +48,8 @@ TfLiteBertQuestionAnswerer* TfLiteBertQuestionAnswererCreate(
- }
-
- TfLiteQaAnswers* TfLiteBertQuestionAnswererAnswer(
- - const TfLiteBertQuestionAnswerer* question_answerer, const char* context,
- + const TfLiteBertQuestionAnswerer* question_answerer,
- + const char* context,
- const char* question) {
- std::vector<QaAnswerCpp> answers = question_answerer->impl->Answer(
- absl::string_view(context).data(), absl::string_view(question).data());
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h
- index 7bc6e6ed385db..e9a1190356914 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h
- @@ -58,7 +58,8 @@ TfLiteBertQuestionAnswerer* TfLiteBertQuestionAnswererCreate(
- // Invokes the encapsulated TFLite model and answers a question based on
- // context.
- TfLiteQaAnswers* TfLiteBertQuestionAnswererAnswer(
- - const TfLiteBertQuestionAnswerer* question_answerer, const char* context,
- + const TfLiteBertQuestionAnswerer* question_answerer,
- + const char* context,
- const char* question);
-
- void TfLiteBertQuestionAnswererDelete(
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc
- index d6d86f67a620a..1e6805c1d1cd6 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc
- @@ -37,7 +37,8 @@ struct TfLiteNLClassifier {
- };
-
- TfLiteNLClassifier* TfLiteNLClassifierCreateFromOptions(
- - const char* model_path, const TfLiteNLClassifierOptions* options) {
- + const char* model_path,
- + const TfLiteNLClassifierOptions* options) {
- auto classifier_status = NLClassifierCpp::CreateFromFileAndOptions(
- std::string(model_path),
- {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h
- index c47dd59b13eb4..389ca5d686df0 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h
- @@ -48,7 +48,8 @@ typedef struct TfLiteNLClassifierOptions {
- // Creates TfLiteNLClassifier from model path and options, returns nullptr if
- // the file doesn't exist or is not a well formatted TFLite model path.
- TfLiteNLClassifier* TfLiteNLClassifierCreateFromOptions(
- - const char* model_path, const TfLiteNLClassifierOptions* options);
- + const char* model_path,
- + const TfLiteNLClassifierOptions* options);
-
- // Invokes the encapsulated TFLite model and classifies the input text.
- Categories* TfLiteNLClassifierClassify(const TfLiteNLClassifier* classifier,
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc
- index 52e215116b51e..183468a6855aa 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc
- @@ -110,7 +110,8 @@ TfLiteImageClassifierOptions TfLiteImageClassifierOptionsCreate(void) {
- }
-
- TfLiteImageClassifier* TfLiteImageClassifierFromOptions(
- - const TfLiteImageClassifierOptions* options, TfLiteSupportError** error) {
- + const TfLiteImageClassifierOptions* options,
- + TfLiteSupportError** error) {
- StatusOr<ImageClassifierOptionsCpp> cpp_option_status =
- CreateImageClassifierCppOptionsFromCOptions(options);
-
- @@ -178,7 +179,8 @@ TfLiteClassificationResult* GetClassificationResultCStruct(
-
- TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi(
- const TfLiteImageClassifier* classifier,
- - const TfLiteFrameBuffer* frame_buffer, const TfLiteBoundingBox* roi,
- + const TfLiteFrameBuffer* frame_buffer,
- + const TfLiteBoundingBox* roi,
- TfLiteSupportError** error) {
- if (classifier == nullptr) {
- tflite::support::CreateTfLiteSupportError(
- @@ -221,7 +223,8 @@ TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi(
-
- TfLiteClassificationResult* TfLiteImageClassifierClassify(
- const TfLiteImageClassifier* classifier,
- - const TfLiteFrameBuffer* frame_buffer, TfLiteSupportError** error) {
- + const TfLiteFrameBuffer* frame_buffer,
- + TfLiteSupportError** error) {
- return TfLiteImageClassifierClassifyWithRoi(classifier, frame_buffer, nullptr,
- error);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h
- index dca83e00f9455..837c9894a2302 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h
- @@ -158,7 +158,8 @@ TfLiteImageClassifierOptions TfLiteImageClassifierOptionsCreate(void);
- // TfLiteSupportErrorDelete(error)
- //
- TfLiteImageClassifier* TfLiteImageClassifierFromOptions(
- - const TfLiteImageClassifierOptions* options, TfLiteSupportError** error);
- + const TfLiteImageClassifierOptions* options,
- + TfLiteSupportError** error);
-
- // Invokes the encapsulated TFLite model and classifies the frame_buffer.
- // Returns a pointer to the created classification result in case of success or
- @@ -186,7 +187,8 @@ TfLiteImageClassifier* TfLiteImageClassifierFromOptions(
- //
- TfLiteClassificationResult* TfLiteImageClassifierClassify(
- const TfLiteImageClassifier* classifier,
- - const TfLiteFrameBuffer* frame_buffer, TfLiteSupportError** error);
- + const TfLiteFrameBuffer* frame_buffer,
- + TfLiteSupportError** error);
-
- // Invokes the encapsulated TFLite model and classifies the region of the
- // frame_buffer specified by the bounding box. Same as TfLiteImageClassifier*
- @@ -198,7 +200,8 @@ TfLiteClassificationResult* TfLiteImageClassifierClassify(
- // operations.
- TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi(
- const TfLiteImageClassifier* classifier,
- - const TfLiteFrameBuffer* frame_buffer, const TfLiteBoundingBox* roi,
- + const TfLiteFrameBuffer* frame_buffer,
- + const TfLiteBoundingBox* roi,
- TfLiteSupportError** error);
-
- // Disposes off the image classifier.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_segmenter.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_segmenter.cc
- index e7395ddbde80e..d2cf362e82ed7 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_segmenter.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_segmenter.cc
- @@ -92,7 +92,8 @@ TfLiteImageSegmenterOptions TfLiteImageSegmenterOptionsCreate(void) {
- }
-
- TfLiteImageSegmenter* TfLiteImageSegmenterFromOptions(
- - const TfLiteImageSegmenterOptions* options, TfLiteSupportError** error) {
- + const TfLiteImageSegmenterOptions* options,
- + TfLiteSupportError** error) {
- StatusOr<ImageSegmenterOptionsCpp> cpp_option_status =
- CreateImageSegmenterCppOptionsFromCOptions(options);
-
- @@ -182,7 +183,8 @@ TfLiteSegmentationResult* GetSegmentationResultCStruct(
-
- TfLiteSegmentationResult* TfLiteImageSegmenterSegment(
- const TfLiteImageSegmenter* segmenter,
- - const TfLiteFrameBuffer* frame_buffer, TfLiteSupportError** error) {
- + const TfLiteFrameBuffer* frame_buffer,
- + TfLiteSupportError** error) {
- if (segmenter == nullptr) {
- tflite::support::CreateTfLiteSupportError(
- kInvalidArgumentError, "Expected non null image segmenter.", error);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_segmenter.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_segmenter.h
- index c2964fad2c144..e0dc62e224b99 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_segmenter.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_segmenter.h
- @@ -172,7 +172,8 @@ TfLiteImageSegmenterOptions TfLiteImageSegmenterOptionsCreate(void);
- // TfLiteSupportErrorDelete(error)
- //
- TfLiteImageSegmenter* TfLiteImageSegmenterFromOptions(
- - const TfLiteImageSegmenterOptions* options, TfLiteSupportError** error);
- + const TfLiteImageSegmenterOptions* options,
- + TfLiteSupportError** error);
-
- // Invokes the encapsulated TFLite model and performs image segmentation on
- // the frame_buffer.
- @@ -201,7 +202,8 @@ TfLiteImageSegmenter* TfLiteImageSegmenterFromOptions(
- //
- TfLiteSegmentationResult* TfLiteImageSegmenterSegment(
- const TfLiteImageSegmenter* segmenter,
- - const TfLiteFrameBuffer* frame_buffer, TfLiteSupportError** error);
- + const TfLiteFrameBuffer* frame_buffer,
- + TfLiteSupportError** error);
-
- // Disposes of the image segmenter.
- void TfLiteImageSegmenterDelete(TfLiteImageSegmenter* segmenter);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.cc
- index 1389a2de0ee75..92535e863b9a3 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.cc
- @@ -109,7 +109,8 @@ TfLiteObjectDetectorOptions TfLiteObjectDetectorOptionsCreate(void) {
- }
-
- TfLiteObjectDetector* TfLiteObjectDetectorFromOptions(
- - const TfLiteObjectDetectorOptions* options, TfLiteSupportError** error) {
- + const TfLiteObjectDetectorOptions* options,
- + TfLiteSupportError** error) {
- StatusOr<ObjectDetectorOptionsCpp> cpp_option_status =
- CreateObjectDetectorCppOptionsFromCOptions(options);
-
- @@ -174,7 +175,8 @@ TfLiteDetectionResult* GetDetectionResultCStruct(
- }
-
- TfLiteDetectionResult* TfLiteObjectDetectorDetect(
- - const TfLiteObjectDetector* detector, const TfLiteFrameBuffer* frame_buffer,
- + const TfLiteObjectDetector* detector,
- + const TfLiteFrameBuffer* frame_buffer,
- TfLiteSupportError** error) {
- if (detector == nullptr) {
- tflite::support::CreateTfLiteSupportError(
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h
- index e2e08ec161559..b4d4564fefeb0 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h
- @@ -157,7 +157,8 @@ TfLiteObjectDetectorOptions TfLiteObjectDetectorOptionsCreate(void);
- // TfLiteSupportErrorDelete(error)
- //
- TfLiteObjectDetector* TfLiteObjectDetectorFromOptions(
- - const TfLiteObjectDetectorOptions* options, TfLiteSupportError** error);
- + const TfLiteObjectDetectorOptions* options,
- + TfLiteSupportError** error);
-
- // Invokes the encapsulated TFLite model and performs object detection on the
- // frame_buffer. Returns a pointer to the created object detection result result
- @@ -185,7 +186,8 @@ TfLiteObjectDetector* TfLiteObjectDetectorFromOptions(
- // TfLiteSupportErrorDelete(error)
- //
- TfLiteDetectionResult* TfLiteObjectDetectorDetect(
- - const TfLiteObjectDetector* detector, const TfLiteFrameBuffer* frame_buffer,
- + const TfLiteObjectDetector* detector,
- + const TfLiteFrameBuffer* frame_buffer,
- TfLiteSupportError** error);
-
- // Disposes off the object detector.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/audio_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/audio_classifier_test.cc
- index 17b2a4ccede29..126784cf6c755 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/audio_classifier_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/audio_classifier_test.cc
- @@ -45,9 +45,10 @@ constexpr char kYamNetAudioClassifierWithMetadata[] =
- "yamnet_audio_classifier_with_metadata.tflite";
-
- StatusOr<TfLiteAudioBuffer> LoadAudioBufferFromFileNamed(
- - const std::string wav_file, int buffer_size) {
- - std::string contents = ReadFile(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory, wav_file));
- + const std::string wav_file,
- + int buffer_size) {
- + std::string contents =
- + ReadFile(JoinPath("./" /*test src dir*/, kTestDataDirectory, wav_file));
-
- uint32_t decoded_sample_count;
- uint16_t decoded_channel_count;
- @@ -90,7 +91,8 @@ void Verify(TfLiteClassificationResult* classification_result,
- }
-
- void Verify(TfLiteClassifications& classifications,
- - int expected_categories_size, int expected_head_index,
- + int expected_categories_size,
- + int expected_head_index,
- char const* expected_head_name) {
- EXPECT_EQ(classifications.size, expected_categories_size);
- EXPECT_EQ(classifications.head_index, expected_head_index);
- @@ -101,8 +103,10 @@ void Verify(TfLiteClassifications& classifications,
- EXPECT_NE(classifications.categories, nullptr);
- }
-
- -void Verify(TfLiteCategory& category, int expected_index,
- - char const* expected_label, float expected_score) {
- +void Verify(TfLiteCategory& category,
- + int expected_index,
- + char const* expected_label,
- + float expected_score) {
- const float kPrecision = 1e-6;
- EXPECT_EQ(category.index, expected_index);
- EXPECT_NE(category.label, nullptr);
- @@ -115,7 +119,8 @@ void Verify(TfLiteCategory& category, int expected_index,
- EXPECT_NEAR(category.score, expected_score, kPrecision);
- }
-
- -void Verify(TfLiteSupportError* error, TfLiteSupportErrorCode error_code,
- +void Verify(TfLiteSupportError* error,
- + TfLiteSupportErrorCode error_code,
- char const* message) {
- ASSERT_NE(error, nullptr);
- EXPECT_EQ(error->code, kInvalidArgumentError);
- @@ -133,7 +138,8 @@ TEST_F(AudioClassifierFromOptionsTest, FailsWithMissingModelPathAndError) {
- TfLiteAudioClassifierFromOptions(&options, &error);
-
- EXPECT_EQ(audio_classifier, nullptr);
- - if (audio_classifier) TfLiteAudioClassifierDelete(audio_classifier);
- + if (audio_classifier)
- + TfLiteAudioClassifierDelete(audio_classifier);
-
- Verify(error, kInvalidArgumentError,
- "INVALID_ARGUMENT: Missing mandatory `model_file` field in "
- @@ -143,9 +149,8 @@ TEST_F(AudioClassifierFromOptionsTest, FailsWithMissingModelPathAndError) {
- }
-
- TEST_F(AudioClassifierFromOptionsTest, SucceedsWithModelPath) {
- - std::string model_path =
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kYamNetAudioClassifierWithMetadata);
- + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
- + kYamNetAudioClassifierWithMetadata);
- TfLiteAudioClassifierOptions options = TfLiteAudioClassifierOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- TfLiteAudioClassifier* audio_classifier =
- @@ -158,9 +163,8 @@ TEST_F(AudioClassifierFromOptionsTest, SucceedsWithModelPath) {
- class AudioClassifierClassifyTest : public tflite_shims::testing::Test {
- protected:
- void SetUp() override {
- - std::string model_path =
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kYamNetAudioClassifierWithMetadata);
- + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
- + kYamNetAudioClassifierWithMetadata);
-
- TfLiteAudioClassifierOptions options = TfLiteAudioClassifierOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc
- index 0a59344f4394c..cce2fa63fad17 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc
- @@ -44,8 +44,8 @@ constexpr char kMobileNetQuantizedWithMetadata[] =
- "mobilenet_v1_0.25_224_quant.tflite";
-
- StatusOr<ImageData> LoadImage(const char* image_name) {
- - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, image_name));
- + return DecodeImageFromFile(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
- }
-
- class ImageClassifierFromOptionsTest : public tflite_shims::testing::Test {};
- @@ -56,7 +56,8 @@ TEST_F(ImageClassifierFromOptionsTest, FailsWithNullOptionsAndError) {
- TfLiteImageClassifierFromOptions(nullptr, &error);
-
- EXPECT_EQ(image_classifier, nullptr);
- - if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
- + if (image_classifier)
- + TfLiteImageClassifierDelete(image_classifier);
-
- ASSERT_NE(error, nullptr);
- EXPECT_EQ(error->code, kInvalidArgumentError);
- @@ -71,7 +72,8 @@ TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPath) {
- TfLiteImageClassifier* image_classifier =
- TfLiteImageClassifierFromOptions(&options, nullptr);
- EXPECT_EQ(image_classifier, nullptr);
- - if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
- + if (image_classifier)
- + TfLiteImageClassifierDelete(image_classifier);
- }
-
- TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPathAndError) {
- @@ -82,7 +84,8 @@ TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPathAndError) {
- TfLiteImageClassifierFromOptions(&options, &error);
-
- EXPECT_EQ(image_classifier, nullptr);
- - if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
- + if (image_classifier)
- + TfLiteImageClassifierDelete(image_classifier);
-
- ASSERT_NE(error, nullptr);
- EXPECT_EQ(error->code, kInvalidArgumentError);
- @@ -93,9 +96,8 @@ TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPathAndError) {
- }
-
- TEST_F(ImageClassifierFromOptionsTest, SucceedsWithModelPath) {
- - std::string model_path =
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetQuantizedWithMetadata);
- + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
- + kMobileNetQuantizedWithMetadata);
- TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- TfLiteImageClassifier* image_classifier =
- @@ -106,9 +108,8 @@ TEST_F(ImageClassifierFromOptionsTest, SucceedsWithModelPath) {
- }
-
- TEST_F(ImageClassifierFromOptionsTest, SucceedsWithNumberOfThreadsAndError) {
- - std::string model_path =
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetQuantizedWithMetadata);
- + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
- + kMobileNetQuantizedWithMetadata);
- TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- options.base_options.compute_settings.cpu_settings.num_threads = 3;
- @@ -120,15 +121,16 @@ TEST_F(ImageClassifierFromOptionsTest, SucceedsWithNumberOfThreadsAndError) {
- EXPECT_NE(image_classifier, nullptr);
- EXPECT_EQ(error, nullptr);
-
- - if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
- - if (error) TfLiteSupportErrorDelete(error);
- + if (image_classifier)
- + TfLiteImageClassifierDelete(image_classifier);
- + if (error)
- + TfLiteSupportErrorDelete(error);
- }
-
- TEST_F(ImageClassifierFromOptionsTest,
- FailsWithClassNameDenyListAndClassNameAllowListAndError) {
- - std::string model_path =
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetQuantizedWithMetadata);
- + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
- + kMobileNetQuantizedWithMetadata);
-
- TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- @@ -146,7 +148,8 @@ TEST_F(ImageClassifierFromOptionsTest,
- TfLiteImageClassifierFromOptions(&options, &error);
-
- EXPECT_EQ(image_classifier, nullptr);
- - if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
- + if (image_classifier)
- + TfLiteImageClassifierDelete(image_classifier);
-
- ASSERT_NE(error, nullptr);
- EXPECT_EQ(error->code, kInvalidArgumentError);
- @@ -158,7 +161,8 @@ TEST_F(ImageClassifierFromOptionsTest,
-
- TEST(ImageClassifierNullClassifierClassifyTest,
- FailsWithNullImageClassifierAndError) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- + LoadImage("burger-224.png"));
-
- TfLiteSupportError* error = nullptr;
- TfLiteClassificationResult* classification_result =
- @@ -181,9 +185,8 @@ TEST(ImageClassifierNullClassifierClassifyTest,
- class ImageClassifierClassifyTest : public tflite_shims::testing::Test {
- protected:
- void SetUp() override {
- - std::string model_path =
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetQuantizedWithMetadata);
- + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
- + kMobileNetQuantizedWithMetadata);
-
- TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- @@ -196,7 +199,8 @@ class ImageClassifierClassifyTest : public tflite_shims::testing::Test {
- };
-
- TEST_F(ImageClassifierClassifyTest, SucceedsWithImageData) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- + LoadImage("burger-224.png"));
-
- TfLiteFrameBuffer frame_buffer = {
- .format = kRGB,
- @@ -223,7 +227,8 @@ TEST_F(ImageClassifierClassifyTest, SucceedsWithImageData) {
- }
-
- TEST_F(ImageClassifierClassifyTest, FailsWithNullFrameBufferAndError) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- + LoadImage("burger-224.png"));
-
- TfLiteSupportError* error = nullptr;
- TfLiteClassificationResult* classification_result =
- @@ -244,7 +249,8 @@ TEST_F(ImageClassifierClassifyTest, FailsWithNullFrameBufferAndError) {
- }
-
- TEST_F(ImageClassifierClassifyTest, FailsWithNullImageDataAndError) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- + LoadImage("burger-224.png"));
-
- TfLiteFrameBuffer frame_buffer = {.format = kRGB, .orientation = kTopLeft};
-
- @@ -267,7 +273,8 @@ TEST_F(ImageClassifierClassifyTest, FailsWithNullImageDataAndError) {
- }
-
- TEST_F(ImageClassifierClassifyTest, SucceedsWithRoiWithinImageBounds) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- + LoadImage("burger-224.png"));
-
- TfLiteFrameBuffer frame_buffer = {
- .format = kRGB,
- @@ -298,7 +305,8 @@ TEST_F(ImageClassifierClassifyTest, SucceedsWithRoiWithinImageBounds) {
- }
-
- TEST_F(ImageClassifierClassifyTest, FailsWithRoiOutsideImageBoundsAndError) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- + LoadImage("burger-224.png"));
-
- TfLiteFrameBuffer frame_buffer = {
- .format = kRGB,
- @@ -330,9 +338,8 @@ TEST_F(ImageClassifierClassifyTest, FailsWithRoiOutsideImageBoundsAndError) {
- TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
- SucceedsWithClassNameDenyList) {
- char* denylisted_label_name = (char*)"cheeseburger";
- - std::string model_path =
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetQuantizedWithMetadata);
- + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
- + kMobileNetQuantizedWithMetadata);
-
- TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- @@ -345,7 +352,8 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
- TfLiteImageClassifierFromOptions(&options, nullptr);
- ASSERT_NE(image_classifier, nullptr);
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- + LoadImage("burger-224.png"));
-
- TfLiteFrameBuffer frame_buffer = {
- .format = kRGB,
- @@ -357,7 +365,8 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
- TfLiteImageClassifierClassify(image_classifier, &frame_buffer, nullptr);
-
- ImageDataFree(&image_data);
- - if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
- + if (image_classifier)
- + TfLiteImageClassifierDelete(image_classifier);
-
- ASSERT_NE(classification_result, nullptr);
- EXPECT_GE(classification_result->size, 1);
- @@ -374,10 +383,9 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
- TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
- SucceedsWithClassNameAllowList) {
- char* allowlisted_label_name = (char*)"cheeseburger";
- - std::string model_path =
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetQuantizedWithMetadata)
- - .data();
- + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
- + kMobileNetQuantizedWithMetadata)
- + .data();
-
- TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- @@ -390,7 +398,8 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
- TfLiteImageClassifierFromOptions(&options, nullptr);
- ASSERT_NE(image_classifier, nullptr);
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- + LoadImage("burger-224.png"));
-
- TfLiteFrameBuffer frame_buffer = {
- .format = kRGB,
- @@ -402,7 +411,8 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
- TfLiteImageClassifierClassify(image_classifier, &frame_buffer, nullptr);
-
- ImageDataFree(&image_data);
- - if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
- + if (image_classifier)
- + TfLiteImageClassifierDelete(image_classifier);
-
- ASSERT_NE(classification_result, nullptr);
- EXPECT_GE(classification_result->size, 1);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_segmenter_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_segmenter_test.cc
- index d4c8106b2729d..c03c15d6fe6b7 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_segmenter_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_segmenter_test.cc
- @@ -46,8 +46,8 @@ constexpr char kTestDataDirectory[] =
- constexpr char kDeepLabV3[] = "deeplabv3.tflite";
-
- StatusOr<ImageData> LoadImage(const char* image_name) {
- - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, image_name));
- + return DecodeImageFromFile(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
- }
-
- // The maximum fraction of pixels in the candidate mask that can have a
- @@ -59,8 +59,11 @@ constexpr float kGoldenMaskTolerance = 1e-2;
- // 20 means class index 2, etc.
- constexpr int kGoldenMaskMagnificationFactor = 10;
-
- -void InitializeColoredLabel(TfLiteColoredLabel& colored_label, uint32_t r,
- - uint32_t g, uint32_t b, const char* label) {
- +void InitializeColoredLabel(TfLiteColoredLabel& colored_label,
- + uint32_t r,
- + uint32_t g,
- + uint32_t b,
- + const char* label) {
- colored_label.r = r;
- colored_label.g = g;
- colored_label.b = b;
- @@ -129,7 +132,8 @@ TEST_F(ImageSegmenterFromOptionsTest, FailsWithNullOptionsAndError) {
-
- EXPECT_EQ(image_segmenter, nullptr);
-
- - if (image_segmenter) TfLiteImageSegmenterDelete(image_segmenter);
- + if (image_segmenter)
- + TfLiteImageSegmenterDelete(image_segmenter);
-
- ASSERT_NE(error, nullptr);
- EXPECT_EQ(error->code, kInvalidArgumentError);
- @@ -147,7 +151,8 @@ TEST_F(ImageSegmenterFromOptionsTest, FailsWithMissingModelPath) {
-
- EXPECT_EQ(image_segmenter, nullptr);
-
- - if (image_segmenter) TfLiteImageSegmenterDelete(image_segmenter);
- + if (image_segmenter)
- + TfLiteImageSegmenterDelete(image_segmenter);
- }
-
- TEST_F(ImageSegmenterFromOptionsTest, FailsWithMissingModelPathAndError) {
- @@ -160,7 +165,8 @@ TEST_F(ImageSegmenterFromOptionsTest, FailsWithMissingModelPathAndError) {
-
- EXPECT_EQ(image_segmenter, nullptr);
-
- - if (image_segmenter) TfLiteImageSegmenterDelete(image_segmenter);
- + if (image_segmenter)
- + TfLiteImageSegmenterDelete(image_segmenter);
-
- ASSERT_NE(error, nullptr);
- EXPECT_EQ(error->code, kInvalidArgumentError);
- @@ -171,8 +177,8 @@ TEST_F(ImageSegmenterFromOptionsTest, FailsWithMissingModelPathAndError) {
- }
-
- TEST_F(ImageSegmenterFromOptionsTest, SucceedsWithModelPath) {
- - std::string model_path = JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, kDeepLabV3);
- + std::string model_path =
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3);
-
- TfLiteImageSegmenterOptions options = TfLiteImageSegmenterOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- @@ -186,8 +192,8 @@ TEST_F(ImageSegmenterFromOptionsTest, SucceedsWithModelPath) {
- }
-
- TEST_F(ImageSegmenterFromOptionsTest, SucceedsWithNumberOfThreadsAndError) {
- - std::string model_path = JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, kDeepLabV3);
- + std::string model_path =
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3);
-
- TfLiteImageSegmenterOptions options = TfLiteImageSegmenterOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- @@ -200,13 +206,15 @@ TEST_F(ImageSegmenterFromOptionsTest, SucceedsWithNumberOfThreadsAndError) {
- EXPECT_NE(image_segmenter, nullptr);
- EXPECT_EQ(error, nullptr);
-
- - if (image_segmenter) TfLiteImageSegmenterDelete(image_segmenter);
- - if (error) TfLiteSupportErrorDelete(error);
- + if (image_segmenter)
- + TfLiteImageSegmenterDelete(image_segmenter);
- + if (error)
- + TfLiteSupportErrorDelete(error);
- }
-
- TEST_F(ImageSegmenterFromOptionsTest, FailsWithUnspecifiedOutputTypeAndError) {
- - std::string model_path = JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, kDeepLabV3);
- + std::string model_path =
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3);
-
- TfLiteImageSegmenterOptions options = TfLiteImageSegmenterOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- @@ -219,15 +227,17 @@ TEST_F(ImageSegmenterFromOptionsTest, FailsWithUnspecifiedOutputTypeAndError) {
- EXPECT_EQ(image_segmenter, nullptr);
- EXPECT_NE(error, nullptr);
-
- - if (image_segmenter) TfLiteImageSegmenterDelete(image_segmenter);
- - if (error) TfLiteSupportErrorDelete(error);
- + if (image_segmenter)
- + TfLiteImageSegmenterDelete(image_segmenter);
- + if (error)
- + TfLiteSupportErrorDelete(error);
- }
-
- class ImageSegmenterSegmentTest : public tflite_shims::testing::Test {
- protected:
- void SetUp() override {
- - std::string model_path = JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, kDeepLabV3);
- + std::string model_path =
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3);
-
- TfLiteImageSegmenterOptions options = TfLiteImageSegmenterOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- @@ -241,7 +251,7 @@ class ImageSegmenterSegmentTest : public tflite_shims::testing::Test {
-
- TEST_F(ImageSegmenterSegmentTest, SucceedsWithCategoryMask) {
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- - LoadImage("segmentation_input_rotation0.jpg"));
- + LoadImage("segmentation_input_rotation0.jpg"));
-
- TfLiteFrameBuffer frame_buffer = {
- .format = kRGB,
- @@ -264,7 +274,7 @@ TEST_F(ImageSegmenterSegmentTest, SucceedsWithCategoryMask) {
-
- // Load golden mask output.
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask,
- - LoadImage("segmentation_golden_rotation0.png"));
- + LoadImage("segmentation_golden_rotation0.png"));
-
- int inconsistent_pixels = 0;
- int num_pixels = golden_mask.height * golden_mask.width;
- @@ -285,8 +295,9 @@ TEST_F(ImageSegmenterSegmentTest, SucceedsWithCategoryMask) {
- }
-
- TEST_F(ImageSegmenterSegmentTest, SucceedsWithCategoryMaskAndOrientation) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- - LoadImage("segmentation_input_rotation90_flop.jpg"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + ImageData image_data,
- + LoadImage("segmentation_input_rotation90_flop.jpg"));
-
- TfLiteFrameBuffer frame_buffer = {
- .format = kRGB,
- @@ -308,8 +319,9 @@ TEST_F(ImageSegmenterSegmentTest, SucceedsWithCategoryMaskAndOrientation) {
- segmentation_result->segmentations[0]);
-
- // Load golden mask output.
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask,
- - LoadImage("segmentation_golden_rotation90_flop.png"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + ImageData golden_mask,
- + LoadImage("segmentation_golden_rotation90_flop.png"));
-
- int inconsistent_pixels = 0;
- int num_pixels = golden_mask.height * golden_mask.width;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/object_detector_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/object_detector_test.cc
- index 0171e584fdd3d..78d78f5ddb6d1 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/object_detector_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/object_detector_test.cc
- @@ -46,8 +46,8 @@ constexpr char kMobileSsdWithMetadata[] =
- "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
-
- StatusOr<ImageData> LoadImage(const char* image_name) {
- - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, image_name));
- + return DecodeImageFromFile(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
- }
-
- void VerifyDetection(const TfLiteDetection& detection,
- @@ -96,7 +96,8 @@ TEST_F(ObjectDetectorFromOptionsTest, FailsWithNullOptionsAndError) {
- TfLiteObjectDetectorFromOptions(nullptr, &error);
-
- EXPECT_EQ(object_detector, nullptr);
- - if (object_detector) TfLiteObjectDetectorDelete(object_detector);
- + if (object_detector)
- + TfLiteObjectDetectorDelete(object_detector);
-
- ASSERT_NE(error, nullptr);
- EXPECT_EQ(error->code, kInvalidArgumentError);
- @@ -111,7 +112,8 @@ TEST_F(ObjectDetectorFromOptionsTest, FailsWithMissingModelPath) {
- TfLiteObjectDetector* object_detector =
- TfLiteObjectDetectorFromOptions(&options, nullptr);
- EXPECT_EQ(object_detector, nullptr);
- - if (object_detector) TfLiteObjectDetectorDelete(object_detector);
- + if (object_detector)
- + TfLiteObjectDetectorDelete(object_detector);
- }
-
- TEST_F(ObjectDetectorFromOptionsTest, FailsWithMissingModelPathAndError) {
- @@ -122,7 +124,8 @@ TEST_F(ObjectDetectorFromOptionsTest, FailsWithMissingModelPathAndError) {
- TfLiteObjectDetectorFromOptions(&options, &error);
-
- EXPECT_EQ(object_detector, nullptr);
- - if (object_detector) TfLiteObjectDetectorDelete(object_detector);
- + if (object_detector)
- + TfLiteObjectDetectorDelete(object_detector);
-
- ASSERT_NE(error, nullptr);
- EXPECT_EQ(error->code, kInvalidArgumentError);
- @@ -133,8 +136,8 @@ TEST_F(ObjectDetectorFromOptionsTest, FailsWithMissingModelPathAndError) {
- }
-
- TEST_F(ObjectDetectorFromOptionsTest, SucceedsWithModelPath) {
- - std::string model_path = JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, kMobileSsdWithMetadata);
- + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
- + kMobileSsdWithMetadata);
- TfLiteObjectDetectorOptions options = TfLiteObjectDetectorOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- TfLiteObjectDetector* object_detector =
- @@ -145,8 +148,8 @@ TEST_F(ObjectDetectorFromOptionsTest, SucceedsWithModelPath) {
- }
-
- TEST_F(ObjectDetectorFromOptionsTest, SucceedsWithNumberOfThreadsAndError) {
- - std::string model_path = JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, kMobileSsdWithMetadata);
- + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
- + kMobileSsdWithMetadata);
- TfLiteObjectDetectorOptions options = TfLiteObjectDetectorOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- options.base_options.compute_settings.cpu_settings.num_threads = 3;
- @@ -158,14 +161,16 @@ TEST_F(ObjectDetectorFromOptionsTest, SucceedsWithNumberOfThreadsAndError) {
- EXPECT_NE(object_detector, nullptr);
- EXPECT_EQ(error, nullptr);
-
- - if (object_detector) TfLiteObjectDetectorDelete(object_detector);
- - if (error) TfLiteSupportErrorDelete(error);
- + if (object_detector)
- + TfLiteObjectDetectorDelete(object_detector);
- + if (error)
- + TfLiteSupportErrorDelete(error);
- }
-
- TEST_F(ObjectDetectorFromOptionsTest,
- FailsWithClassNameDenyListAndClassNameAllowListAndError) {
- - std::string model_path = JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, kMobileSsdWithMetadata);
- + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
- + kMobileSsdWithMetadata);
-
- TfLiteObjectDetectorOptions options = TfLiteObjectDetectorOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- @@ -183,7 +188,8 @@ TEST_F(ObjectDetectorFromOptionsTest,
- TfLiteObjectDetectorFromOptions(&options, &error);
-
- EXPECT_EQ(object_detector, nullptr);
- - if (object_detector) TfLiteObjectDetectorDelete(object_detector);
- + if (object_detector)
- + TfLiteObjectDetectorDelete(object_detector);
-
- ASSERT_NE(error, nullptr);
- EXPECT_EQ(error->code, kInvalidArgumentError);
- @@ -195,7 +201,8 @@ TEST_F(ObjectDetectorFromOptionsTest,
-
- TEST(ObjectDetectorNullDetectorDetectTest,
- FailsWithNullObjectDetectorAndError) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("cats_and_dogs.jpg"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- + LoadImage("cats_and_dogs.jpg"));
-
- TfLiteSupportError* error = nullptr;
- TfLiteDetectionResult* detection_result =
- @@ -204,7 +211,8 @@ TEST(ObjectDetectorNullDetectorDetectTest,
- ImageDataFree(&image_data);
-
- EXPECT_EQ(detection_result, nullptr);
- - if (detection_result) TfLiteDetectionResultDelete(detection_result);
- + if (detection_result)
- + TfLiteDetectionResultDelete(detection_result);
-
- ASSERT_NE(error, nullptr);
- EXPECT_EQ(error->code, kInvalidArgumentError);
- @@ -217,9 +225,8 @@ TEST(ObjectDetectorNullDetectorDetectTest,
- class ObjectDetectorDetectTest : public tflite_shims::testing::Test {
- protected:
- void SetUp() override {
- - std::string model_path =
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileSsdWithMetadata);
- + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
- + kMobileSsdWithMetadata);
-
- TfLiteObjectDetectorOptions options = TfLiteObjectDetectorOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- @@ -232,7 +239,8 @@ class ObjectDetectorDetectTest : public tflite_shims::testing::Test {
- };
-
- TEST_F(ObjectDetectorDetectTest, SucceedsWithImageData) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("cats_and_dogs.jpg"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- + LoadImage("cats_and_dogs.jpg"));
-
- TfLiteFrameBuffer frame_buffer = {
- .format = kRGB,
- @@ -251,7 +259,8 @@ TEST_F(ObjectDetectorDetectTest, SucceedsWithImageData) {
- }
-
- TEST_F(ObjectDetectorDetectTest, FailsWithNullFrameBufferAndError) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("cats_and_dogs.jpg"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- + LoadImage("cats_and_dogs.jpg"));
-
- TfLiteSupportError* error = nullptr;
- TfLiteDetectionResult* detection_result =
- @@ -260,7 +269,8 @@ TEST_F(ObjectDetectorDetectTest, FailsWithNullFrameBufferAndError) {
- ImageDataFree(&image_data);
-
- EXPECT_EQ(detection_result, nullptr);
- - if (detection_result) TfLiteDetectionResultDelete(detection_result);
- + if (detection_result)
- + TfLiteDetectionResultDelete(detection_result);
-
- ASSERT_NE(error, nullptr);
- EXPECT_EQ(error->code, kInvalidArgumentError);
- @@ -271,7 +281,8 @@ TEST_F(ObjectDetectorDetectTest, FailsWithNullFrameBufferAndError) {
- }
-
- TEST_F(ObjectDetectorDetectTest, FailsWithNullImageDataAndError) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("cats_and_dogs.jpg"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- + LoadImage("cats_and_dogs.jpg"));
- TfLiteSupportError* error = nullptr;
- TfLiteDetectionResult* detection_result =
- TfLiteObjectDetectorDetect(object_detector, nullptr, &error);
- @@ -279,7 +290,8 @@ TEST_F(ObjectDetectorDetectTest, FailsWithNullImageDataAndError) {
- ImageDataFree(&image_data);
-
- EXPECT_EQ(detection_result, nullptr);
- - if (detection_result) TfLiteDetectionResultDelete(detection_result);
- + if (detection_result)
- + TfLiteDetectionResultDelete(detection_result);
-
- ASSERT_NE(error, nullptr);
- EXPECT_EQ(error->code, kInvalidArgumentError);
- @@ -292,8 +304,8 @@ TEST_F(ObjectDetectorDetectTest, FailsWithNullImageDataAndError) {
- TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
- SucceedsWithClassNameDenyList) {
- char* denylisted_label_name = (char*)"cat";
- - std::string model_path = JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, kMobileSsdWithMetadata);
- + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
- + kMobileSsdWithMetadata);
-
- TfLiteObjectDetectorOptions options = TfLiteObjectDetectorOptionsCreate();
- options.base_options.model_file.file_path = model_path.data();
- @@ -306,7 +318,8 @@ TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
- TfLiteObjectDetectorFromOptions(&options, nullptr);
- ASSERT_NE(object_detector, nullptr);
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("cats_and_dogs.jpg"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- + LoadImage("cats_and_dogs.jpg"));
-
- TfLiteFrameBuffer frame_buffer = {
- .format = kRGB,
- @@ -318,7 +331,8 @@ TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
- TfLiteObjectDetectorDetect(object_detector, &frame_buffer, nullptr);
-
- ImageDataFree(&image_data);
- - if (object_detector) TfLiteObjectDetectorDelete(object_detector);
- + if (object_detector)
- + TfLiteObjectDetectorDelete(object_detector);
-
- ASSERT_NE(detection_result, nullptr);
- EXPECT_GE(detection_result->size, 1);
- @@ -334,8 +348,8 @@ TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
- TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
- SucceedsWithClassNameAllowList) {
- char* allowlisted_label_name = (char*)"cat";
- - std::string model_path = JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, kMobileSsdWithMetadata)
- + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
- + kMobileSsdWithMetadata)
- .data();
-
- TfLiteObjectDetectorOptions options = TfLiteObjectDetectorOptionsCreate();
- @@ -349,7 +363,8 @@ TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
- TfLiteObjectDetectorFromOptions(&options, nullptr);
- ASSERT_NE(object_detector, nullptr);
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("cats_and_dogs.jpg"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- + LoadImage("cats_and_dogs.jpg"));
-
- TfLiteFrameBuffer frame_buffer = {
- .format = kRGB,
- @@ -361,7 +376,8 @@ TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
- TfLiteObjectDetectorDetect(object_detector, &frame_buffer, nullptr);
-
- ImageDataFree(&image_data);
- - if (object_detector) TfLiteObjectDetectorDelete(object_detector);
- + if (object_detector)
- + TfLiteObjectDetectorDelete(object_detector);
-
- ASSERT_NE(detection_result, nullptr);
- EXPECT_GE(detection_result->size, 1);
- @@ -376,8 +392,8 @@ TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
-
- TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
- SucceedsWithScoreThreshold) {
- - std::string model_path = JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, kMobileSsdWithMetadata)
- + std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
- + kMobileSsdWithMetadata)
- .data();
-
- TfLiteObjectDetectorOptions options = TfLiteObjectDetectorOptionsCreate();
- @@ -389,7 +405,8 @@ TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
- TfLiteObjectDetectorFromOptions(&options, nullptr);
- ASSERT_NE(object_detector, nullptr);
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("cats_and_dogs.jpg"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
- + LoadImage("cats_and_dogs.jpg"));
-
- TfLiteFrameBuffer frame_buffer = {
- .format = kRGB,
- @@ -401,7 +418,8 @@ TEST(ObjectDetectorWithUserDefinedOptionsDetectorTest,
- TfLiteObjectDetectorDetect(object_detector, &frame_buffer, nullptr);
-
- ImageDataFree(&image_data);
- - if (object_detector) TfLiteObjectDetectorDelete(object_detector);
- + if (object_detector)
- + TfLiteObjectDetectorDelete(object_detector);
-
- ASSERT_NE(detection_result, nullptr);
- EXPECT_EQ(detection_result->size, 1);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc
- index abfef722d6659..09e9a83e07bef 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc
- @@ -15,7 +15,7 @@ limitations under the License.
-
- #include "tensorflow_lite_support/cc/common.h"
-
- -#include "absl/strings/cord.h" // from @com_google_absl
- +#include "absl/strings/cord.h" // from @com_google_absl
- #include "absl/strings/str_cat.h" // from @com_google_absl
-
- namespace tflite {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h
- index b06e9f58459af..71dd920b86bed 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h
- @@ -16,7 +16,7 @@ limitations under the License.
- #ifndef TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_
- #define TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
-
- namespace tflite {
- @@ -164,7 +164,8 @@ enum class TfLiteSupportStatus {
- // more than returning an object identical to an OK status. See `absl::Status`
- // for more details.
- absl::Status CreateStatusWithPayload(
- - absl::StatusCode canonical_code, absl::string_view message,
- + absl::StatusCode canonical_code,
- + absl::string_view message,
- tflite::support::TfLiteSupportStatus tfls_code =
- tflite::support::TfLiteSupportStatus::kError);
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h
- index 14999ca37b7ac..cb145dbd232c8 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h
- @@ -18,7 +18,7 @@ limitations under the License.
- #define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_
-
- #include "absl/base/optimization.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
-
- // Evaluates an expression that produces a `absl::Status`. If the status is not
- // ok, returns it from the current function.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h
- index dc04c293c6ffd..81ec3c1ab5f86 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h
- @@ -21,8 +21,8 @@ limitations under the License.
- #include <utility>
-
- #include "absl/meta/type_traits.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/utility/utility.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/utility/utility.h" // from @com_google_absl
-
- namespace tflite {
- namespace support {
- @@ -63,7 +63,8 @@ struct IsDirectInitializationAmbiguous
- U>::value,
- std::false_type,
- IsDirectInitializationAmbiguous<
- - T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
- + T,
- + absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
-
- template <typename T, typename V>
- struct IsDirectInitializationAmbiguous<T, tflite::support::StatusOr<V>>
- @@ -101,7 +102,8 @@ struct IsForwardingAssignmentAmbiguous
- U>::value,
- std::false_type,
- IsForwardingAssignmentAmbiguous<
- - T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
- + T,
- + absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
-
- template <typename T, typename U>
- struct IsForwardingAssignmentAmbiguous<T, tflite::support::StatusOr<U>>
- @@ -136,7 +138,8 @@ template <typename T, typename... Args>
- void PlacementNew(void* p, Args&&... args) {
- #if defined(__GNUC__) && !defined(__clang__)
- // Teach gcc that 'p' cannot be null, fixing code size issues.
- - if (p == nullptr) __builtin_unreachable();
- + if (p == nullptr)
- + __builtin_unreachable();
- #endif
- new (p) T(std::forward<Args>(args)...);
- }
- @@ -207,7 +210,8 @@ class StatusOrData {
- }
-
- StatusOrData& operator=(const StatusOrData& other) {
- - if (this == &other) return *this;
- + if (this == &other)
- + return *this;
- if (other.ok())
- Assign(other.data_);
- else
- @@ -216,7 +220,8 @@ class StatusOrData {
- }
-
- StatusOrData& operator=(StatusOrData&& other) {
- - if (this == &other) return *this;
- + if (this == &other)
- + return *this;
- if (other.ok())
- Assign(std::move(other.data_));
- else
- @@ -295,15 +300,18 @@ class StatusOrData {
- };
-
- void Clear() {
- - if (ok()) data_.~T();
- + if (ok())
- + data_.~T();
- }
-
- void EnsureOk() const {
- - if (ABSL_PREDICT_FALSE(!ok())) Helper::Crash(status_);
- + if (ABSL_PREDICT_FALSE(!ok()))
- + Helper::Crash(status_);
- }
-
- void EnsureNotOk() {
- - if (ABSL_PREDICT_FALSE(ok())) Helper::HandleInvalidStatusCtorArg(&status_);
- + if (ABSL_PREDICT_FALSE(ok()))
- + Helper::HandleInvalidStatusCtorArg(&status_);
- }
-
- // Construct the value (ie. data_) through placement new with the passed
- @@ -362,8 +370,9 @@ struct MoveCtorBase<T, false> {
- MoveCtorBase& operator=(MoveCtorBase&&) = default;
- };
-
- -template <typename T, bool = std::is_copy_constructible<T>::value&&
- - std::is_copy_assignable<T>::value>
- +template <typename T,
- + bool = std::is_copy_constructible<T>::value&&
- + std::is_copy_assignable<T>::value>
- struct CopyAssignBase {
- CopyAssignBase() = default;
- CopyAssignBase(const CopyAssignBase&) = default;
- @@ -381,8 +390,9 @@ struct CopyAssignBase<T, false> {
- CopyAssignBase& operator=(CopyAssignBase&&) = default;
- };
-
- -template <typename T, bool = std::is_move_constructible<T>::value&&
- - std::is_move_assignable<T>::value>
- +template <typename T,
- + bool = std::is_move_constructible<T>::value&&
- + std::is_move_assignable<T>::value>
- struct MoveAssignBase {
- MoveAssignBase() = default;
- MoveAssignBase(const MoveAssignBase&) = default;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
- index 11f9d584cfdd0..4d23efe43bc99 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
- @@ -15,7 +15,7 @@ limitations under the License.
-
- #include "tensorflow_lite_support/cc/port/default/tflite_wrapper.h"
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/delegates/interpreter_utils.h"
- @@ -310,7 +310,9 @@ absl::Status TfLiteInterpreterWrapper::InvokeWithoutFallback() {
- return absl::OkStatus();
- }
-
- -void TfLiteInterpreterWrapper::Cancel() { cancel_flag_.Set(true); }
- +void TfLiteInterpreterWrapper::Cancel() {
- + cancel_flag_.Set(true);
- +}
-
- void TfLiteInterpreterWrapper::SetTfLiteCancellation() {
- // Create a cancellation check function and set to the TFLite interpreter.
- @@ -323,7 +325,8 @@ void TfLiteInterpreterWrapper::SetTfLiteCancellation() {
- }
-
- absl::Status TfLiteInterpreterWrapper::LoadDelegatePlugin(
- - const std::string& name, const tflite::TFLiteSettings& tflite_settings) {
- + const std::string& name,
- + const tflite::TFLiteSettings& tflite_settings) {
- delegate_plugin_ = DelegatePluginRegistry::CreateByName(
- absl::StrFormat("%sPlugin", name), tflite_settings);
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h
- index 9a6fdebd99903..a9deed9f93521 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h
- @@ -19,7 +19,7 @@ limitations under the License.
- #include <string>
- #include <utility>
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "flatbuffers/flatbuffers.h" // from @flatbuffers
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/experimental/acceleration/configuration/configuration.pb.h"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h
- index 0d808ab24d6cc..dc6183bee693c 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h
- @@ -37,7 +37,7 @@ typedef unsigned long uword_t;
- #define GG_LL_FORMAT "ll" // As in "%lld". Note that "q" is poor form also.
- #define GG_LL_FORMAT_W L"ll"
-
- -const uint8 kuint8max{0xFF};
- +const uint8 kuint8max{0xFF};
- const uint16 kuint16max{0xFFFF};
- const uint32 kuint32max{0xFFFFFFFF};
- const uint64 kuint64max{GG_ULONGLONG(0xFFFFFFFFFFFFFFFF)};
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc
- index 4b1439dcc0719..4be3e53c11972 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc
- @@ -17,7 +17,7 @@ limitations under the License.
-
- #include <initializer_list>
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow/lite/c/c_api_types.h"
- #include "tensorflow_lite_support/cc/common.h"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.cc
- index 56acada352121..a01effd031e29 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.cc
- @@ -29,7 +29,8 @@ namespace audio {
-
- /* static */
- tflite::support::StatusOr<double> AudioEmbedder::CosineSimilarity(
- - const processor::FeatureVector& u, const processor::FeatureVector& v) {
- + const processor::FeatureVector& u,
- + const processor::FeatureVector& v) {
- return processor::EmbeddingPostprocessor::CosineSimilarity(u, v);
- }
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h
- index f6df6d4d58552..4a139ee8bf82d 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h
- @@ -27,9 +27,9 @@ limitations under the License.
- namespace tflite {
- namespace task {
- namespace audio {
- -class AudioEmbedder
- - : public tflite::task::core::BaseTaskApi<
- - tflite::task::processor::EmbeddingResult, const AudioBuffer&> {
- +class AudioEmbedder : public tflite::task::core::BaseTaskApi<
- + tflite::task::processor::EmbeddingResult,
- + const AudioBuffer&> {
- public:
- // Use base class constructor.
- using BaseTaskApi::BaseTaskApi;
- @@ -41,7 +41,8 @@ class AudioEmbedder
- //
- // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
- static tflite::support::StatusOr<double> CosineSimilarity(
- - const processor::FeatureVector& u, const processor::FeatureVector& v);
- + const processor::FeatureVector& u,
- + const processor::FeatureVector& v);
-
- // Creates an AudioEmbedder from the provided options. A non-default
- // OpResolver can be specified in order to support custom Ops or specify a
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h
- index 39110ed8d0b15..d922e48af25bc 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h
- @@ -17,8 +17,8 @@ limitations under the License.
-
- #include <memory>
-
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/common.h"
- #include "tensorflow_lite_support/cc/port/statusor.h"
- @@ -41,7 +41,8 @@ class AudioBuffer {
- // Factory method for creating an AudioBuffer object. The internal buffer does
- // not take the ownership of the input backing buffer.
- static tflite::support::StatusOr<std::unique_ptr<AudioBuffer>> Create(
- - const float* audio_buffer, int buffer_size,
- + const float* audio_buffer,
- + int buffer_size,
- const AudioFormat& audio_format) {
- return absl::make_unique<AudioBuffer>(audio_buffer, buffer_size,
- audio_format);
- @@ -50,7 +51,8 @@ class AudioBuffer {
- // AudioBuffer for internal use only. Uses the factory method to construct
- // AudioBuffer instance. The internal buffer does not take the ownership of
- // the input backing buffer.
- - AudioBuffer(const float* audio_buffer, int buffer_size,
- + AudioBuffer(const float* audio_buffer,
- + int buffer_size,
- const AudioFormat& audio_format)
- : audio_buffer_(audio_buffer),
- buffer_size_(buffer_size),
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/audio_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/audio_utils.cc
- index 1a27c6b44c1bf..c013759b13ebb 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/audio_utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/audio_utils.cc
- @@ -20,7 +20,8 @@ namespace task {
- namespace audio {
-
- tflite::support::StatusOr<AudioBuffer> LoadAudioBufferFromFile(
- - const std::string& wav_file_path, int buffer_size,
- + const std::string& wav_file_path,
- + int buffer_size,
- std::vector<float>* wav_data) {
- std::string contents = ReadFile(wav_file_path);
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/audio_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/audio_utils.h
- index 68880c0cb4072..123d5a1f6fbf7 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/audio_utils.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/audio_utils.h
- @@ -34,7 +34,8 @@ namespace audio {
- // object, the user of this function has to make sure that wav_data outlives the
- // returned AudioBuffer object.
- tflite::support::StatusOr<AudioBuffer> LoadAudioBufferFromFile(
- - const std::string& wav_file_path, int buffer_size,
- + const std::string& wav_file_path,
- + int buffer_size,
- std::vector<float>* wav_data);
-
- } // namespace audio
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc
- index 3c0ad996a9919..9ae3fbec70543 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc
- @@ -27,9 +27,9 @@ limitations under the License.
- #include <fstream>
- #include <limits>
-
- -#include "absl/base/casts.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/str_cat.h" // from @com_google_absl
- +#include "absl/base/casts.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/str_cat.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/status_macros.h"
-
- @@ -62,7 +62,9 @@ std::string ReadFile(const std::string filepath) {
-
- // Handles moving the data index forward, validating the arguments, and avoiding
- // overflow or underflow.
- -absl::Status IncrementOffset(int old_offset, size_t increment, size_t max_size,
- +absl::Status IncrementOffset(int old_offset,
- + size_t increment,
- + size_t max_size,
- int* new_offset) {
- if (old_offset < 0) {
- return absl::InvalidArgumentError(
- @@ -87,7 +89,8 @@ absl::Status IncrementOffset(int old_offset, size_t increment, size_t max_size,
- }
-
- absl::Status ExpectText(const std::string& data,
- - const std::string& expected_text, int* offset) {
- + const std::string& expected_text,
- + int* offset) {
- int new_offset;
- RETURN_IF_ERROR(
- IncrementOffset(*offset, expected_text.size(), data.size(), &new_offset));
- @@ -101,8 +104,10 @@ absl::Status ExpectText(const std::string& data,
- return absl::OkStatus();
- }
-
- -absl::Status ReadString(const std::string& data, int expected_length,
- - std::string* value, int* offset) {
- +absl::Status ReadString(const std::string& data,
- + int expected_length,
- + std::string* value,
- + int* offset) {
- int new_offset;
- RETURN_IF_ERROR(
- IncrementOffset(*offset, expected_length, data.size(), &new_offset));
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h
- index 51271fc065c83..9aca5d06f7985 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h
- @@ -20,9 +20,9 @@ limitations under the License.
-
- #define TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_UTILS_WAV_IO_H_
-
- +#include <cstdint>
- #include <string>
- #include <vector>
- -#include <cstdint>
-
- #include "absl/status/status.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- @@ -64,7 +64,9 @@ absl::Status DecodeLin16WaveAsFloatVector(const std::string& wav_string,
-
- // Handles moving the data index forward, validating the arguments, and avoiding
- // overflow or underflow.
- -absl::Status IncrementOffset(int old_offset, size_t increment, size_t max_size,
- +absl::Status IncrementOffset(int old_offset,
- + size_t increment,
- + size_t max_size,
- int* new_offset);
-
- // This function is only exposed in the header for testing purposes, as a
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h
- index d743383734b42..effd42f0f0336 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h
- @@ -18,7 +18,7 @@ limitations under the License.
-
- #include <utility>
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow_lite_support/cc/common.h"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h
- index c868060f9894a..c91552f7ec82e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h
- @@ -18,7 +18,7 @@ limitations under the License.
- #include <string>
- #include <vector>
-
- -#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/core/label_map_item.h"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc
- index 80dea95cce24b..a626ce6030b96 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc
- @@ -35,9 +35,13 @@ int ErrorReporter::Report(const char* format, va_list args) {
- return num_characters;
- }
-
- -std::string ErrorReporter::message() { return last_message_; }
- +std::string ErrorReporter::message() {
- + return last_message_;
- +}
-
- -std::string ErrorReporter::previous_message() { return second_last_message_; }
- +std::string ErrorReporter::previous_message() {
- + return second_last_message_;
- +}
-
- } // namespace core
- } // namespace task
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc
- index 9c4cc2009baea..e15830d5ab061 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc
- @@ -18,11 +18,11 @@ limitations under the License.
- #include <memory>
- #include <string>
-
- -#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/common.h"
- -#include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- +#include "tensorflow_lite_support/cc/port/statusor.h"
-
- namespace tflite {
- namespace task {
- @@ -57,11 +57,10 @@ absl::Status ExternalFileHandler::MapExternalFile() {
- StatusCode::kInvalidArgument,
- "ExternalFile must specify 'file_content' in Chromium.",
- TfLiteSupportStatus::kInvalidArgumentError);
- -
- }
-
- absl::string_view ExternalFileHandler::GetFileContent() {
- - return external_file_.file_content();
- + return external_file_.file_content();
- }
-
- ExternalFileHandler::~ExternalFileHandler() = default;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h
- index a7daa175f77f5..9f35fdd6d09ce 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h
- @@ -18,7 +18,7 @@ limitations under the License.
-
- #include <memory>
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/integral_types.h"
- #include "tensorflow_lite_support/cc/port/statusor.h"
- @@ -64,7 +64,6 @@ class ExternalFileHandler {
-
- // Reference to the input ExternalFile.
- const ExternalFile& external_file_;
- -
- };
-
- } // namespace core
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc
- index 694c55ab34e78..72e4b670cb172 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc
- @@ -15,7 +15,7 @@ limitations under the License.
- #include "tensorflow_lite_support/cc/task/core/label_map_item.h"
-
- #include "absl/strings/str_format.h" // from @com_google_absl
- -#include "absl/strings/str_split.h" // from @com_google_absl
- +#include "absl/strings/str_split.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/common.h"
-
- namespace tflite {
- @@ -28,7 +28,8 @@ using ::tflite::support::StatusOr;
- using ::tflite::support::TfLiteSupportStatus;
-
- StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
- - absl::string_view labels_file, absl::string_view display_names_file) {
- + absl::string_view labels_file,
- + absl::string_view display_names_file) {
- if (labels_file.empty()) {
- return CreateStatusWithPayload(StatusCode::kInvalidArgument,
- "Expected non-empty labels file.",
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h
- index 4d8422a2a572d..d8e1f70d8fab1 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h
- @@ -20,8 +20,8 @@ limitations under the License.
-
- #include "absl/container/flat_hash_map.h" // from @com_google_absl
- #include "absl/container/flat_hash_set.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/string_view.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
-
- namespace tflite {
- @@ -49,7 +49,8 @@ struct LabelMapItem {
- // Returns an error e.g. if there's a mismatch between the number of labels and
- // display names.
- tflite::support::StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
- - absl::string_view labels_file, absl::string_view display_names_file);
- + absl::string_view labels_file,
- + absl::string_view display_names_file);
-
- // A class that represents a hierarchy of labels as specified in a label map.
- //
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc
- index 818839a77e43d..e7faebad487b9 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc
- @@ -19,11 +19,11 @@ limitations under the License.
- #include <utility>
- #include <vector>
-
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/str_format.h" // from @com_google_absl
- -#include "absl/strings/str_split.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/str_format.h" // from @com_google_absl
- +#include "absl/strings/str_split.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- -#include "absl/types/optional.h" // from @com_google_absl
- +#include "absl/types/optional.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/common.h"
- #include "tensorflow_lite_support/cc/port/status_macros.h"
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h
- index c1b945f76ab48..6e2b308bef101 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h
- @@ -23,9 +23,9 @@ limitations under the License.
- #include <vector>
-
- #include "absl/container/flat_hash_map.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/string_view.h" // from @com_google_absl
- -#include "absl/types/optional.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/string_view.h" // from @com_google_absl
- +#include "absl/types/optional.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/core/label_map_item.h"
- #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
- @@ -37,7 +37,10 @@ namespace core {
- // Sigmoid structure.
- struct Sigmoid {
- Sigmoid() : scale(1.0) {}
- - Sigmoid(std::string label, float slope, float offset, float scale = 1.0,
- + Sigmoid(std::string label,
- + float slope,
- + float offset,
- + float scale = 1.0,
- absl::optional<float> min_uncalibrated_score = absl::nullopt)
- : label(label),
- slope(slope),
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h
- index 3d3bc801a6e5d..bbe549a802b39 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h
- @@ -18,7 +18,7 @@ limitations under the License.
-
- #include <memory>
-
- -#include "absl/base/macros.h" // from @com_google_absl
- +#include "absl/base/macros.h" // from @com_google_absl
- #include "absl/status/status.h" // from @com_google_absl
- #include "tensorflow/lite/core/api/op_resolver.h"
- #include "tensorflow/lite/kernels/op_macros.h"
- @@ -48,7 +48,8 @@ class TaskAPIFactory {
- "Use CreateFromBaseOptions and configure model input from "
- "tensorflow_lite_support/cc/task/core/proto/base_options.proto")
- static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromBuffer(
- - const char* buffer_data, size_t buffer_size,
- + const char* buffer_data,
- + size_t buffer_size,
- std::unique_ptr<tflite::OpResolver> resolver =
- absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>(),
- int num_threads = 1,
- @@ -156,7 +157,8 @@ class TaskAPIFactory {
- private:
- template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr>
- static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromTfLiteEngine(
- - std::unique_ptr<TfLiteEngine> engine, int num_threads,
- + std::unique_ptr<TfLiteEngine> engine,
- + int num_threads,
- const tflite::proto::ComputeSettings& compute_settings =
- tflite::proto::ComputeSettings()) {
- tflite::proto::ComputeSettings settings_copy =
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h
- index 9c26d154634e1..2c21a95a1b075 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h
- @@ -21,12 +21,12 @@ limitations under the License.
- #include <numeric>
- #include <vector>
-
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/str_cat.h" // from @com_google_absl
- -#include "absl/strings/str_format.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/str_cat.h" // from @com_google_absl
- +#include "absl/strings/str_format.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- -#include "flatbuffers/flatbuffers.h" // from @flatbuffers
- +#include "flatbuffers/flatbuffers.h" // from @flatbuffers
- #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
- #include "tensorflow/lite/kernels/op_macros.h"
- #include "tensorflow/lite/string_util.h"
- @@ -66,9 +66,11 @@ tflite::support::StatusOr<T*> AssertAndReturnTypedTensor(
- // type or has not the same number of elements.
- // Note: std::negation is not used because it is from C++17, where the code will
- // be compiled using C++14 in OSS.
- -template <typename T, typename = std::enable_if_t<
- - std::is_same<T, std::string>::value == false>>
- -inline absl::Status PopulateTensor(const T* data, int num_elements,
- +template <
- + typename T,
- + typename = std::enable_if_t<std::is_same<T, std::string>::value == false>>
- +inline absl::Status PopulateTensor(const T* data,
- + int num_elements,
- TfLiteTensor* tensor) {
- T* v;
- ASSIGN_OR_RETURN(v, AssertAndReturnTypedTensor<T>(tensor));
- @@ -93,7 +95,8 @@ inline absl::Status PopulateTensor(const std::vector<T>& data,
-
- template <>
- inline absl::Status PopulateTensor<std::string>(
- - const std::vector<std::string>& data, TfLiteTensor* tensor) {
- + const std::vector<std::string>& data,
- + TfLiteTensor* tensor) {
- if (tensor->type != kTfLiteString) {
- return tflite::support::CreateStatusWithPayload(
- absl::StatusCode::kInternal,
- @@ -144,7 +147,8 @@ inline absl::Status PopulateVector(const TfLiteTensor* tensor,
-
- template <>
- inline absl::Status PopulateVector<std::string>(
- - const TfLiteTensor* tensor, std::vector<std::string>* data) {
- + const TfLiteTensor* tensor,
- + std::vector<std::string>* data) {
- std::string* v __attribute__((unused));
- ASSIGN_OR_RETURN(v, AssertAndReturnTypedTensor<std::string>(tensor));
- int num = GetStringCount(tensor);
- @@ -160,7 +164,8 @@ inline absl::Status PopulateVector<std::string>(
- // Note: std::negation is not used because it is from C++17, where the code will
- // be compiled using C++14 in OSS.
- template <
- - class TRepeatedField, class T = float,
- + class TRepeatedField,
- + class T = float,
- typename = std::enable_if_t<std::is_same<T, std::string>::value == false>>
- inline absl::Status PopulateVectorToRepeated(const TfLiteTensor* tensor,
- TRepeatedField* data) {
- @@ -236,7 +241,8 @@ int FindTensorIndexByName(
- if (tensor_metadata != nullptr && tensor_metadata->size() == tensors.size()) {
- int index =
- FindTensorIndexByMetadataName(tensor_metadata, metadata_tensor_name);
- - if (index > -1) return index;
- + if (index > -1)
- + return index;
- }
-
- return FindTensorIndexByModelName(tensors, model_tensor_name);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc
- index 5999090cab973..41e06389af80b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc
- @@ -17,7 +17,7 @@ limitations under the License.
-
- #include <memory>
-
- -#include "absl/strings/match.h" // from @com_google_absl
- +#include "absl/strings/match.h" // from @com_google_absl
- #include "absl/strings/str_cat.h" // from @com_google_absl
- #include "tensorflow/lite/builtin_ops.h"
- #include "tensorflow/lite/core/shims/cc/kernels/register.h"
- @@ -38,7 +38,8 @@ using ::tflite::support::CreateStatusWithPayload;
- using ::tflite::support::InterpreterCreationResources;
- using ::tflite::support::TfLiteSupportStatus;
-
- -bool TfLiteEngine::Verifier::Verify(const char* data, int length,
- +bool TfLiteEngine::Verifier::Verify(const char* data,
- + int length,
- tflite::ErrorReporter* reporter) {
- return tflite_shims::Verify(data, length, reporter);
- }
- @@ -69,7 +70,8 @@ std::vector<const TfLiteTensor*> TfLiteEngine::GetOutputs() {
- }
-
- void TfLiteEngine::VerifyAndBuildModelFromBuffer(
- - const char* buffer_data, size_t buffer_size,
- + const char* buffer_data,
- + size_t buffer_size,
- TfLiteVerifier* extra_verifier) {
- model_ = tflite_shims::FlatBufferModel::VerifyAndBuildFromBuffer(
- buffer_data, buffer_size, extra_verifier, &error_reporter_);
- @@ -116,7 +118,8 @@ absl::Status TfLiteEngine::InitializeFromModelFileHandler(
- }
-
- absl::Status TfLiteEngine::BuildModelFromFlatBuffer(
- - const char* buffer_data, size_t buffer_size,
- + const char* buffer_data,
- + size_t buffer_size,
- const tflite::proto::ComputeSettings& compute_settings) {
- if (model_) {
- return CreateStatusWithPayload(StatusCode::kInternal,
- @@ -205,7 +208,8 @@ absl::Status TfLiteEngine::InitInterpreter(int num_threads) {
- // absl::Status TfLiteEngine::InitInterpreter(
- // const tflite::proto::ComputeSettings& compute_settings)
- absl::Status TfLiteEngine::InitInterpreter(
- - const tflite::proto::ComputeSettings& compute_settings, int num_threads) {
- + const tflite::proto::ComputeSettings& compute_settings,
- + int num_threads) {
- ComputeSettings settings_copy = ComputeSettings(compute_settings);
- settings_copy.mutable_tflite_settings()
- ->mutable_cpu_settings()
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h
- index 53dabdc4841d7..0cbaa738e6db6 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h
- @@ -18,8 +18,8 @@ limitations under the License.
-
- #include <memory>
-
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow/lite/core/api/op_resolver.h"
- #include "tensorflow/lite/core/shims/c/common.h"
- @@ -96,7 +96,8 @@ class TfLiteEngine {
- // object. This performs extra verification on the input data using
- // tflite::Verify.
- absl::Status BuildModelFromFlatBuffer(
- - const char* buffer_data, size_t buffer_size,
- + const char* buffer_data,
- + size_t buffer_size,
- const tflite::proto::ComputeSettings& compute_settings =
- tflite::proto::ComputeSettings());
-
- @@ -138,7 +139,8 @@ class TfLiteEngine {
- // absl::Status TfLiteEngine::InitInterpreter(
- // const tflite::proto::ComputeSettings& compute_settings)
- absl::Status InitInterpreter(
- - const tflite::proto::ComputeSettings& compute_settings, int num_threads);
- + const tflite::proto::ComputeSettings& compute_settings,
- + int num_threads);
-
- // Cancels the on-going `Invoke()` call if any and if possible. This method
- // can be called from a different thread than the one where `Invoke()` is
- @@ -155,7 +157,8 @@ class TfLiteEngine {
- // the FlatBuffer data provided as input.
- class Verifier : public tflite::TfLiteVerifier {
- public:
- - bool Verify(const char* data, int length,
- + bool Verify(const char* data,
- + int length,
- tflite::ErrorReporter* reporter) override;
- };
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc
- index e3ea2b134e3f4..254d0689e5ecc 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc
- @@ -14,7 +14,7 @@ limitations under the License.
- ==============================================================================*/
- #include "tensorflow_lite_support/cc/task/processor/audio_preprocessor.h"
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/common.h"
- #include "tensorflow_lite_support/cc/port/statusor.h"
- @@ -29,7 +29,8 @@ namespace {
- // Looks up AudioProperty from metadata. If no error occurs, the returned value
- // is guaranteed to be valid (not null).
- tflite::support::StatusOr<const AudioProperties*> GetAudioPropertiesSafe(
- - const TensorMetadata* tensor_metadata, int input_index) {
- + const TensorMetadata* tensor_metadata,
- + int input_index) {
- if (tensor_metadata->content() == nullptr ||
- tensor_metadata->content()->content_properties() == nullptr) {
- return CreateStatusWithPayload(
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc
- index 9c11083c4f839..63962003f5e77 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc
- @@ -17,7 +17,7 @@ limitations under the License.
-
- #include <memory>
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow/lite/c/c_api_types.h"
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- @@ -42,7 +42,8 @@ using ::tflite::task::core::ScoreCalibration;
- /* static */
- tflite::support::StatusOr<std::unique_ptr<ClassificationPostprocessor>>
- ClassificationPostprocessor::Create(
- - core::TfLiteEngine* engine, const std::initializer_list<int> output_indices,
- + core::TfLiteEngine* engine,
- + const std::initializer_list<int> output_indices,
- std::unique_ptr<ClassificationOptions> options) {
- ASSIGN_OR_RETURN(auto processor,
- Processor::Create<ClassificationPostprocessor>(
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h
- index 7863e3aa82fb7..f04048d84b4ce 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h
- @@ -69,8 +69,8 @@ class EmbeddingPostprocessor : public Postprocessor {
-
- // Performs actual cosine similarity computation.
- template <typename T>
- - static tflite::support::StatusOr<double> ComputeCosineSimilarity(
- - const T* u, const T* v, int num_elements);
- + static tflite::support::StatusOr<double>
- + ComputeCosineSimilarity(const T* u, const T* v, int num_elements);
-
- template <typename T>
- void NormalizeFeatureVector(T* feature_vector) const;
- @@ -146,7 +146,8 @@ void EmbeddingPostprocessor::QuantizeFeatureVector(T* feature_vector) const {
- /* static */
- template <typename T>
- tflite::support::StatusOr<double>
- -EmbeddingPostprocessor::ComputeCosineSimilarity(const T* u, const T* v,
- +EmbeddingPostprocessor::ComputeCosineSimilarity(const T* u,
- + const T* v,
- int num_elements) {
- if (num_elements <= 0) {
- return CreateStatusWithPayload(
- @@ -174,7 +175,8 @@ EmbeddingPostprocessor::ComputeCosineSimilarity(const T* u, const T* v,
- /* static */
- template <typename T>
- tflite::support::StatusOr<double> EmbeddingPostprocessor::CosineSimilarity(
- - const T& u, const T& v) {
- + const T& u,
- + const T& v) {
- if (u.has_value_string() && v.has_value_string()) {
- if (u.value_string().size() != v.value_string().size()) {
- return CreateStatusWithPayload(
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc
- index 7ad4ad4703789..310a1f5eba724 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc
- @@ -36,7 +36,8 @@ using ::tflite::task::vision::FrameBuffer;
- /* static */
- tflite::support::StatusOr<std::unique_ptr<ImagePreprocessor>>
- ImagePreprocessor::Create(
- - core::TfLiteEngine* engine, const std::initializer_list<int> input_indices,
- + core::TfLiteEngine* engine,
- + const std::initializer_list<int> input_indices,
- const vision::FrameBufferUtils::ProcessEngine& process_engine) {
- ASSIGN_OR_RETURN(auto processor,
- Processor::Create<ImagePreprocessor>(
- @@ -49,7 +50,8 @@ ImagePreprocessor::Create(
-
- // Returns false if image preprocessing could be skipped, true otherwise.
- bool ImagePreprocessor::IsImagePreprocessingNeeded(
- - const FrameBuffer& frame_buffer, const BoundingBox& roi) {
- + const FrameBuffer& frame_buffer,
- + const BoundingBox& roi) {
- // Is crop required?
- if (roi.origin_x() != 0 || roi.origin_y() != 0 ||
- roi.width() != frame_buffer.dimension().width ||
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h
- index 4aad40b2afd97..b3c43605ac82e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h
- @@ -18,7 +18,7 @@ limitations under the License.
- #include <initializer_list>
- #include <vector>
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow/lite/core/shims/c/common.h"
- #include "tensorflow_lite_support/cc/common.h"
- @@ -52,7 +52,8 @@ class Processor {
- // num_expected_tensors, engine, tensor_indices);
- template <typename T, EnableIfProcessorSubclass<T> = nullptr>
- static tflite::support::StatusOr<std::unique_ptr<T>> Create(
- - int num_expected_tensors, tflite::task::core::TfLiteEngine* engine,
- + int num_expected_tensors,
- + tflite::task::core::TfLiteEngine* engine,
- const std::initializer_list<int> tensor_indices,
- bool requires_metadata = true) {
- auto processor = absl::make_unique<T>(engine, tensor_indices);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc
- index af923b4d6f2c1..58b77b6952de1 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc
- @@ -55,7 +55,8 @@ StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile(
-
- /* static */
- StatusOr<std::unique_ptr<RegexPreprocessor>> RegexPreprocessor::Create(
- - tflite::task::core::TfLiteEngine* engine, int input_tensor_index) {
- + tflite::task::core::TfLiteEngine* engine,
- + int input_tensor_index) {
- ASSIGN_OR_RETURN(auto processor, Processor::Create<RegexPreprocessor>(
- /* num_expected_tensors = */ 1, engine,
- {input_tensor_index},
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h
- index 1f92bcc18e524..bdd4e5e207a12 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h
- @@ -34,7 +34,8 @@ namespace processor {
- class RegexPreprocessor : public TextPreprocessor {
- public:
- static tflite::support::StatusOr<std::unique_ptr<RegexPreprocessor>> Create(
- - tflite::task::core::TfLiteEngine* engine, int input_tensor_index);
- + tflite::task::core::TfLiteEngine* engine,
- + int input_tensor_index);
-
- absl::Status Preprocess(const std::string& text);
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc
- index 730c9919cadee..a2fa1f8243199 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc
- @@ -22,17 +22,12 @@ limitations under the License.
- #include <memory>
- #include <vector>
-
- -#include "tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h"
- -#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h"
- -#include "tensorflow_lite_support/scann_ondevice/cc/core/searcher.h"
- -#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
- -#include "tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h"
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/str_format.h" // from @com_google_absl
- +#include "Eigen/Core" // from @eigen
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/str_format.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- -#include "absl/types/span.h" // from @com_google_absl
- -#include "Eigen/Core" // from @eigen
- +#include "absl/types/span.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/common.h"
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- #include "tensorflow_lite_support/cc/port/statusor.h"
- @@ -45,6 +40,11 @@ limitations under the License.
- #include "tensorflow_lite_support/cc/task/processor/proto/search_result.pb.h"
- #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
- #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
- +#include "tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h"
- +#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h"
- +#include "tensorflow_lite_support/scann_ondevice/cc/core/searcher.h"
- +#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
- +#include "tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h"
- #include "tensorflow_lite_support/scann_ondevice/cc/index.h"
- #include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h"
-
- @@ -56,16 +56,16 @@ namespace {
-
- constexpr int kNoNeighborId = -1;
-
- +using ::tflite::TensorMetadata;
- +using ::tflite::metadata::ModelMetadataExtractor;
- +using ::tflite::scann_ondevice::Index;
- +using ::tflite::scann_ondevice::IndexConfig;
- using ::tflite::scann_ondevice::core::AsymmetricHashFindNeighbors;
- using ::tflite::scann_ondevice::core::DistanceMeasure;
- using ::tflite::scann_ondevice::core::FloatFindNeighbors;
- using ::tflite::scann_ondevice::core::QueryInfo;
- using ::tflite::scann_ondevice::core::ScannOnDeviceConfig;
- using ::tflite::scann_ondevice::core::TopN;
- -using ::tflite::TensorMetadata;
- -using ::tflite::metadata::ModelMetadataExtractor;
- -using ::tflite::scann_ondevice::Index;
- -using ::tflite::scann_ondevice::IndexConfig;
- using ::tflite::support::CreateStatusWithPayload;
- using ::tflite::support::StatusOr;
- using ::tflite::support::TfLiteSupportStatus;
- @@ -212,7 +212,8 @@ absl::Status ConvertEmbeddingToEigenMatrix(const Embedding& embedding,
-
- /* static */
- StatusOr<std::unique_ptr<SearchPostprocessor>> SearchPostprocessor::Create(
- - TfLiteEngine* engine, int output_index,
- + TfLiteEngine* engine,
- + int output_index,
- std::unique_ptr<SearchOptions> search_options,
- std::unique_ptr<EmbeddingOptions> embedding_options) {
- ASSIGN_OR_RETURN(auto embedding_postprocessor,
- @@ -316,7 +317,8 @@ absl::Status SearchPostprocessor::Init(
- index_config_.scann_config().partitioner().search_fraction())),
- partitioner_->NumPartitions());
- } else {
- - partitioner_ = absl::make_unique<tflite::scann_ondevice::core::NoOpPartitioner>();
- + partitioner_ =
- + absl::make_unique<tflite::scann_ondevice::core::NoOpPartitioner>();
- num_leaves_to_search_ = partitioner_->NumPartitions();
- }
-
- @@ -330,7 +332,8 @@ absl::Status SearchPostprocessor::Init(
- }
-
- absl::Status SearchPostprocessor::QuantizedSearch(
- - Eigen::Ref<Eigen::MatrixXf> query, std::vector<int> leaves_to_search,
- + Eigen::Ref<Eigen::MatrixXf> query,
- + std::vector<int> leaves_to_search,
- absl::Span<TopN> top_n) {
- int dim = index_config_.embedding_dim();
- // Prepare QueryInfo used for all leaves.
- @@ -360,7 +363,8 @@ absl::Status SearchPostprocessor::QuantizedSearch(
- }
-
- absl::Status SearchPostprocessor::LinearSearch(
- - Eigen::Ref<Eigen::MatrixXf> query, std::vector<int> leaves_to_search,
- + Eigen::Ref<Eigen::MatrixXf> query,
- + std::vector<int> leaves_to_search,
- absl::Span<TopN> top_n) {
- int dim = index_config_.embedding_dim();
- for (int leaf_id : leaves_to_search) {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.h
- index 47c78b64ba2ca..d79bc853148a9 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.h
- @@ -21,14 +21,9 @@ limitations under the License.
- #include <memory>
- #include <vector>
-
- -#include "tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h"
- -#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h"
- -#include "tensorflow_lite_support/scann_ondevice/cc/core/searcher.h"
- -#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
- -#include "tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h"
- +#include "Eigen/Core" // from @eigen
- #include "absl/strings/string_view.h" // from @com_google_absl
- -#include "absl/types/span.h" // from @com_google_absl
- -#include "Eigen/Core" // from @eigen
- +#include "absl/types/span.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
- #include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
- @@ -37,6 +32,11 @@ limitations under the License.
- #include "tensorflow_lite_support/cc/task/processor/proto/embedding_options.pb.h"
- #include "tensorflow_lite_support/cc/task/processor/proto/search_options.pb.h"
- #include "tensorflow_lite_support/cc/task/processor/proto/search_result.pb.h"
- +#include "tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h"
- +#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h"
- +#include "tensorflow_lite_support/scann_ondevice/cc/core/searcher.h"
- +#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
- +#include "tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h"
- #include "tensorflow_lite_support/scann_ondevice/cc/index.h"
- #include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h"
-
- @@ -55,7 +55,8 @@ namespace processor {
- class SearchPostprocessor : public Postprocessor {
- public:
- static tflite::support::StatusOr<std::unique_ptr<SearchPostprocessor>> Create(
- - tflite::task::core::TfLiteEngine* engine, int output_index,
- + tflite::task::core::TfLiteEngine* engine,
- + int output_index,
- std::unique_ptr<SearchOptions> search_options,
- std::unique_ptr<EmbeddingOptions> embedding_options =
- std::make_unique<EmbeddingOptions>());
- @@ -76,12 +77,14 @@ class SearchPostprocessor : public Postprocessor {
- std::unique_ptr<EmbeddingPostprocessor> embedding_postprocessor,
- std::unique_ptr<SearchOptions> options);
-
- - absl::Status QuantizedSearch(Eigen::Ref<Eigen::MatrixXf> query,
- - std::vector<int> leaves_to_search,
- - absl::Span<tflite::scann_ondevice::core::TopN> top_n);
- - absl::Status LinearSearch(Eigen::Ref<Eigen::MatrixXf> query,
- - std::vector<int> leaves_to_search,
- - absl::Span<tflite::scann_ondevice::core::TopN> top_n);
- + absl::Status QuantizedSearch(
- + Eigen::Ref<Eigen::MatrixXf> query,
- + std::vector<int> leaves_to_search,
- + absl::Span<tflite::scann_ondevice::core::TopN> top_n);
- + absl::Status LinearSearch(
- + Eigen::Ref<Eigen::MatrixXf> query,
- + std::vector<int> leaves_to_search,
- + absl::Span<tflite::scann_ondevice::core::TopN> top_n);
-
- std::unique_ptr<SearchOptions> options_;
-
- @@ -96,8 +99,10 @@ class SearchPostprocessor : public Postprocessor {
- // ScaNN management.
- int num_leaves_to_search_;
- tflite::scann_ondevice::core::DistanceMeasure distance_measure_;
- - std::unique_ptr<tflite::scann_ondevice::core::PartitionerInterface> partitioner_;
- - std::shared_ptr<tflite::scann_ondevice::core::AsymmetricHashQuerier> quantizer_;
- + std::unique_ptr<tflite::scann_ondevice::core::PartitionerInterface>
- + partitioner_;
- + std::shared_ptr<tflite::scann_ondevice::core::AsymmetricHashQuerier>
- + quantizer_;
- };
-
- } // namespace processor
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_clu_annotator.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_clu_annotator.cc
- index f60a556dbbe1b..802facec374f3 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_clu_annotator.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_clu_annotator.cc
- @@ -19,9 +19,9 @@ limitations under the License.
- #include <string>
- #include <utility>
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- -#include "absl/strings/str_split.h" // from @com_google_absl
- +#include "absl/strings/str_split.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
- #include "tensorflow_lite_support/cc/task/core/task_utils.h"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc
- index 52c898dacb9ca..d4481cdd17874 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc
- @@ -57,7 +57,8 @@ absl::Status SanityCheckOptions(const BertNLClassifierOptions& options) {
- } // namespace
-
- absl::Status BertNLClassifier::Preprocess(
- - const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
- + const std::vector<TfLiteTensor*>& input_tensors,
- + const std::string& input) {
- return preprocessor_->Preprocess(input);
- }
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h
- index 4151025df917b..bcc9c5a533a3f 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h
- @@ -22,7 +22,7 @@ limitations under the License.
- #include <string>
- #include <vector>
-
- -#include "absl/base/macros.h" // from @com_google_absl
- +#include "absl/base/macros.h" // from @com_google_absl
- #include "absl/status/status.h" // from @com_google_absl
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/core/api/op_resolver.h"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc
- index 6b37649d4fbfd..b886e3b362902 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc
- @@ -15,7 +15,7 @@ limitations under the License.
-
- #include "tensorflow_lite_support/cc/task/text/bert_question_answerer.h"
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_join.h" // from @com_google_absl
- #include "absl/strings/str_split.h" // from @com_google_absl
- #include "tensorflow/lite/core/shims/cc/kernels/register.h"
- @@ -111,7 +111,8 @@ StatusOr<std::unique_ptr<QuestionAnswerer>> BertQuestionAnswerer::CreateFromFd(
-
- StatusOr<std::unique_ptr<QuestionAnswerer>>
- BertQuestionAnswerer::CreateBertQuestionAnswererFromFile(
- - const std::string& path_to_model, const std::string& path_to_vocab) {
- + const std::string& path_to_model,
- + const std::string& path_to_vocab) {
- std::unique_ptr<BertQuestionAnswerer> api_to_init;
- ASSIGN_OR_RETURN(
- api_to_init,
- @@ -125,8 +126,10 @@ BertQuestionAnswerer::CreateBertQuestionAnswererFromFile(
-
- StatusOr<std::unique_ptr<QuestionAnswerer>>
- BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer(
- - const char* model_buffer_data, size_t model_buffer_size,
- - const char* vocab_buffer_data, size_t vocab_buffer_size) {
- + const char* model_buffer_data,
- + size_t model_buffer_size,
- + const char* vocab_buffer_data,
- + size_t vocab_buffer_size) {
- std::unique_ptr<BertQuestionAnswerer> api_to_init;
- ASSIGN_OR_RETURN(
- api_to_init,
- @@ -141,7 +144,8 @@ BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer(
-
- StatusOr<std::unique_ptr<QuestionAnswerer>>
- BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile(
- - const std::string& path_to_model, const std::string& path_to_spmodel) {
- + const std::string& path_to_model,
- + const std::string& path_to_spmodel) {
- std::unique_ptr<BertQuestionAnswerer> api_to_init;
- ASSIGN_OR_RETURN(
- api_to_init,
- @@ -155,8 +159,10 @@ BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile(
-
- StatusOr<std::unique_ptr<QuestionAnswerer>>
- BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer(
- - const char* model_buffer_data, size_t model_buffer_size,
- - const char* spmodel_buffer_data, size_t spmodel_buffer_size) {
- + const char* model_buffer_data,
- + size_t model_buffer_size,
- + const char* spmodel_buffer_data,
- + size_t spmodel_buffer_size) {
- std::unique_ptr<BertQuestionAnswerer> api_to_init;
- ASSIGN_OR_RETURN(
- api_to_init,
- @@ -170,14 +176,16 @@ BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer(
- }
-
- std::vector<QaAnswer> BertQuestionAnswerer::Answer(
- - const std::string& context, const std::string& question) {
- + const std::string& context,
- + const std::string& question) {
- // The BertQuestionAnswererer implementation for Preprocess() and
- // Postprocess() never returns errors: just call value().
- return Infer(context, question).value();
- }
-
- absl::Status BertQuestionAnswerer::Preprocess(
- - const std::vector<TfLiteTensor*>& input_tensors, const std::string& context,
- + const std::vector<TfLiteTensor*>& input_tensors,
- + const std::string& context,
- const std::string& query) {
- auto* input_tensor_metadatas =
- GetMetadataExtractor()->GetInputTensorMetadata();
- @@ -392,7 +400,8 @@ void BertQuestionAnswerer::InitializeBertTokenizer(
- }
-
- void BertQuestionAnswerer::InitializeBertTokenizerFromBinary(
- - const char* vocab_buffer_data, size_t vocab_buffer_size) {
- + const char* vocab_buffer_data,
- + size_t vocab_buffer_size) {
- tokenizer_ =
- absl::make_unique<BertTokenizer>(vocab_buffer_data, vocab_buffer_size);
- }
- @@ -403,7 +412,8 @@ void BertQuestionAnswerer::InitializeSentencepieceTokenizer(
- }
-
- void BertQuestionAnswerer::InitializeSentencepieceTokenizerFromBinary(
- - const char* spmodel_buffer_data, size_t spmodel_buffer_size) {
- + const char* spmodel_buffer_data,
- + size_t spmodel_buffer_size) {
- tokenizer_ = absl::make_unique<SentencePieceTokenizer>(spmodel_buffer_data,
- spmodel_buffer_size);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h
- index f041cc8e51637..52ec835371386 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h
- @@ -16,9 +16,9 @@ limitations under the License.
- #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_QA_BERT_QUESTION_ANSWERER_H_
- #define TENSORFLOW_LITE_SUPPORT_CC_TASK_QA_BERT_QUESTION_ANSWERER_H_
-
- -#include "absl/base/macros.h" // from @com_google_absl
- +#include "absl/base/macros.h" // from @com_google_absl
- #include "absl/container/flat_hash_map.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/core/base_task_api.h"
- #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
- @@ -136,7 +136,8 @@ class BertQuestionAnswerer : public QuestionAnswerer {
- void InitializeSentencepieceTokenizer(const std::string& path_to_spmodel);
- // Initialize API with a SentencepieceTokenizer from the model buffer.
- void InitializeSentencepieceTokenizerFromBinary(
- - const char* spmodel_buffer_data, size_t spmodel_buffer_size);
- + const char* spmodel_buffer_data,
- + size_t spmodel_buffer_size);
-
- // Initialize the API with the tokenizer set in the metadata.
- absl::Status InitializeFromMetadata(
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/bert_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/bert_utils.cc
- index 0164bf48f156e..dc88aad9c2bdf 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/bert_utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/bert_utils.cc
- @@ -17,8 +17,8 @@ limitations under the License.
-
- #include <string>
-
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/ascii.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/ascii.h" // from @com_google_absl
- #include "absl/strings/str_cat.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- #include "tensorflow_lite_support/cc/task/text/clu_lib/constants.h"
- @@ -46,10 +46,13 @@ constexpr int kTurnIdForCurrentUtterance = 0;
- absl::Status BertPreprocessing(
- const tflite::support::text::tokenizer::BertTokenizer* tokenizer,
- const std::vector<absl::string_view>& utterances_in_reverse_order,
- - int max_seq_length, int max_history_turns, std::vector<int>* out_token_ids,
- + int max_seq_length,
- + int max_history_turns,
- + std::vector<int>* out_token_ids,
- std::vector<std::pair<int, int>>* out_token_alignments,
- std::vector<int>* out_token_first_subword_indicators,
- - std::vector<int>* out_segment_id_list, std::vector<int>* out_turn_id_list) {
- + std::vector<int>* out_segment_id_list,
- + std::vector<int>* out_turn_id_list) {
- int cls_id;
- if (!tokenizer->LookupId(kClsToken, &cls_id)) {
- return absl::InternalError(
- @@ -183,7 +186,8 @@ absl::Status BertPreprocessing(
- out_turn_id_list->push_back(turn_id);
-
- // Break if reaching max_seq_length.
- - if (out_token_ids->size() >= max_seq_length) break;
- + if (out_token_ids->size() >= max_seq_length)
- + break;
- }
- if (out_token_ids->size() != out_token_alignments->size()) {
- return absl::InternalError(absl::StrCat(
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/bert_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/bert_utils.h
- index 69d13be6ce114..c3b3f6c4caf78 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/bert_utils.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/bert_utils.h
- @@ -80,10 +80,13 @@ namespace tflite::task::text::clu {
- absl::Status BertPreprocessing(
- const tflite::support::text::tokenizer::BertTokenizer* tokenizer,
- const std::vector<absl::string_view>& utterances_in_reverse_order,
- - int max_seq_length, int max_history_turns, std::vector<int>* out_token_ids,
- + int max_seq_length,
- + int max_history_turns,
- + std::vector<int>* out_token_ids,
- std::vector<std::pair<int, int>>* out_token_alignments,
- std::vector<int>* out_token_first_subword_indicators,
- - std::vector<int>* out_segment_id_list, std::vector<int>* out_turn_id_list);
- + std::vector<int>* out_segment_id_list,
- + std::vector<int>* out_turn_id_list);
-
- } // namespace tflite::task::text::clu
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/intent_repr.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/intent_repr.cc
- index b310a0782c69f..037566235cf7c 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/intent_repr.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/intent_repr.cc
- @@ -17,9 +17,9 @@ limitations under the License.
-
- #include <vector>
-
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/str_cat.h" // from @com_google_absl
- -#include "absl/strings/str_split.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/str_cat.h" // from @com_google_absl
- +#include "absl/strings/str_split.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/task/text/clu_lib/constants.h"
-
- @@ -28,7 +28,8 @@ namespace tflite::task::text::clu {
- // IntentRepr
-
- std::string IntentRepr::FullName() const {
- - if (domain_.empty()) return name_;
- + if (domain_.empty())
- + return name_;
- return absl::StrCat(domain_, kNamespaceDelim, name_);
- }
-
- @@ -40,16 +41,19 @@ absl::StatusOr<IntentRepr> IntentRepr::CreateFromFullName(
- if (splits.size() > 2) {
- return absl::InternalError(absl::StrCat("invalid argument: ", full_name));
- }
- - if (splits.size() == 2) ret.domain_ = splits[0];
- + if (splits.size() == 2)
- + ret.domain_ = splits[0];
- ret.name_ = splits[splits.size() - 1];
- return ret;
- }
-
- -IntentRepr IntentRepr::Create(absl::string_view name, absl::string_view domain,
- +IntentRepr IntentRepr::Create(absl::string_view name,
- + absl::string_view domain,
- const bool share_across_domains) {
- IntentRepr ret;
- ret.name_ = std::string(name);
- - if (!share_across_domains) ret.domain_ = std::string(domain);
- + if (!share_across_domains)
- + ret.domain_ = std::string(domain);
- return ret;
- }
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/intent_repr.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/intent_repr.h
- index 9084deb1203b4..e040b04d998ea 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/intent_repr.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/intent_repr.h
- @@ -18,7 +18,7 @@ limitations under the License.
-
- #include <string>
-
- -#include "absl/status/statusor.h" // from @com_google_absl
- +#include "absl/status/statusor.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
-
- namespace tflite::task::text::clu {
- @@ -30,7 +30,8 @@ class IntentRepr {
- const std::string& Name() const { return name_; }
- std::string FullName() const;
- static absl::StatusOr<IntentRepr> CreateFromFullName(const absl::string_view);
- - static IntentRepr Create(absl::string_view name, absl::string_view domain,
- + static IntentRepr Create(absl::string_view name,
- + absl::string_view domain,
- const bool share_across_domains);
-
- private:
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.cc
- index 114a721ee40ef..dbb0dc2a14263 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.cc
- @@ -20,15 +20,15 @@ limitations under the License.
- #include <memory>
- #include <vector>
-
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/status/statusor.h" // from @com_google_absl
- -#include "absl/strings/match.h" // from @com_google_absl
- -#include "absl/strings/str_cat.h" // from @com_google_absl
- -#include "absl/strings/str_split.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/statusor.h" // from @com_google_absl
- +#include "absl/strings/match.h" // from @com_google_absl
- +#include "absl/strings/str_cat.h" // from @com_google_absl
- +#include "absl/strings/str_split.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- -#include "absl/strings/strip.h" // from @com_google_absl
- -#include "absl/strings/substitute.h" // from @com_google_absl
- -#include "absl/types/span.h" // from @com_google_absl
- +#include "absl/strings/strip.h" // from @com_google_absl
- +#include "absl/strings/substitute.h" // from @com_google_absl
- +#include "absl/types/span.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- #include "tensorflow_lite_support/cc/task/text/clu_lib/constants.h"
-
- @@ -39,7 +39,8 @@ using ::absl::StatusOr;
- // SlotRepr
-
- std::string SlotRepr::FullName() const {
- - if (domain_.empty()) return name_;
- + if (domain_.empty())
- + return name_;
- return absl::StrCat(domain_, kNamespaceDelim, name_);
- }
-
- @@ -52,14 +53,16 @@ SlotRepr::SplitDomainAndName(const absl::string_view full_name) {
- }
- absl::string_view domain = "";
- absl::string_view name;
- - if (splits.size() == 2) domain = splits[0];
- + if (splits.size() == 2)
- + domain = splits[0];
- name = splits[splits.size() - 1];
- return std::tuple<absl::string_view, absl::string_view>{domain, name};
- }
-
- StatusOr<SlotRepr> SlotRepr::CreateFromIob(const absl::string_view repr) {
- SlotRepr ret;
- - if (IsO(repr)) return ret;
- + if (IsO(repr))
- + return ret;
- absl::string_view full_name;
- if (absl::StartsWith(repr, kSlotBTagPrefix)) {
- full_name = absl::StripPrefix(repr, kSlotBTagPrefix);
- @@ -76,7 +79,8 @@ StatusOr<SlotRepr> SlotRepr::CreateFromIob(const absl::string_view repr) {
- return ret;
- }
-
- -SlotRepr SlotRepr::Create(absl::string_view name, absl::string_view domain,
- +SlotRepr SlotRepr::Create(absl::string_view name,
- + absl::string_view domain,
- const bool share_across_domains) {
- SlotRepr ret;
- ret.name_ = std::string(name);
- @@ -94,7 +98,9 @@ bool SlotRepr::IsB(const absl::string_view repr) {
- return absl::StartsWith(repr, kSlotBTagPrefix);
- }
-
- -bool SlotRepr::IsO(const absl::string_view repr) { return repr == kSlotOTag; }
- +bool SlotRepr::IsO(const absl::string_view repr) {
- + return repr == kSlotOTag;
- +}
-
- bool SlotRepr::operator==(const SlotRepr& other) const {
- return domain_ == other.domain_ && name_ == other.name_;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.h
- index 04ca49b268917..9a5f68a00bdcd 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.h
- @@ -20,9 +20,9 @@ limitations under the License.
- #include <utility>
- #include <vector>
-
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/status/statusor.h" // from @com_google_absl
- -#include "absl/strings/str_cat.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/statusor.h" // from @com_google_absl
- +#include "absl/strings/str_cat.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/task/text/clu_lib/constants.h"
-
- @@ -68,7 +68,8 @@ class SlotRepr {
- static absl::StatusOr<SlotRepr> CreateFromIob(const absl::string_view);
-
- // Factory
- - static SlotRepr Create(absl::string_view name, absl::string_view domain = "",
- + static SlotRepr Create(absl::string_view name,
- + absl::string_view domain = "",
- const bool share_across_domains = true);
-
- // Splits the full_name into domain and slot name.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_tagging_output.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_tagging_output.cc
- index 716d29b76a98f..0d5abb443fcc3 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_tagging_output.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_tagging_output.cc
- @@ -17,10 +17,10 @@ limitations under the License.
-
- #include <vector>
-
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/status/statusor.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/statusor.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- -#include "absl/types/span.h" // from @com_google_absl
- +#include "absl/types/span.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- #include "tensorflow_lite_support/cc/task/text/clu_lib/slot_repr.h"
-
- @@ -29,7 +29,9 @@ namespace {
-
- absl::StatusOr<std::vector<SlotMentionStruct>>
- DecodeSlotChunksPredictOnFirstSubword(
- - int cur_turn_start, int cur_turn_end, int seq_len,
- + int cur_turn_start,
- + int cur_turn_end,
- + int seq_len,
- const absl::Span<const absl::string_view> tags_as_span,
- const absl::Span<const float> confidences_as_span,
- const absl::Span<const std::pair<int, int>> token_alignments_as_span,
- @@ -74,10 +76,12 @@ DecodeSlotChunksPredictOnFirstSubword(
- } // namespace
-
- absl::Status SlotModulePopulateResponse(
- - const std::vector<absl::string_view>& tags, const float* confidences,
- + const std::vector<absl::string_view>& tags,
- + const float* confidences,
- const std::vector<std::pair<int, int>>& token_alignments,
- const std::vector<int>& token_turn_ids,
- - const std::vector<int>& first_subword_indicators, float threshold,
- + const std::vector<int>& first_subword_indicators,
- + float threshold,
- const std::vector<absl::string_view>& reverse_utterance_list_to_encode,
- CluResponse* response) {
- if (token_alignments.size() != token_turn_ids.size()) {
- @@ -104,7 +108,7 @@ absl::Status SlotModulePopulateResponse(
-
- // Prepare the data and decode slot chunks.
- std::vector<SlotMentionStruct> cur_turn_slot_mentions;
- - // Decode slot chunks based on first subword tokens in the turn.
- + // Decode slot chunks based on first subword tokens in the turn.
- ASSIGN_OR_RETURN(cur_turn_slot_mentions,
- DecodeSlotChunksPredictOnFirstSubword(
- cur_turn_start, cur_turn_end, seq_len, tags_as_span,
- @@ -113,8 +117,10 @@ absl::Status SlotModulePopulateResponse(
-
- // Populate the response.
- for (const auto& chunk : cur_turn_slot_mentions) {
- - if (chunk.start == -1 || cur_turn_idx != 0) continue;
- - if (chunk.confidence < threshold) continue;
- + if (chunk.start == -1 || cur_turn_idx != 0)
- + continue;
- + if (chunk.confidence < threshold)
- + continue;
- auto slot = response->mutable_noncategorical_slots()->Add();
- slot->set_slot(chunk.repr.Name());
- auto extraction = slot->mutable_extraction();
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_tagging_output.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_tagging_output.h
- index b8fc64425634e..7d2b9a1a1fd27 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_tagging_output.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/slot_tagging_output.h
- @@ -41,10 +41,12 @@ namespace tflite::task::text::clu {
- // Outputs:
- // response
- absl::Status SlotModulePopulateResponse(
- - const std::vector<absl::string_view>& tags, const float* confidences,
- + const std::vector<absl::string_view>& tags,
- + const float* confidences,
- const std::vector<std::pair<int, int>>& token_alignments,
- const std::vector<int>& token_turn_ids,
- - const std::vector<int>& first_subword_indicators, float threshold,
- + const std::vector<int>& first_subword_indicators,
- + float threshold,
- const std::vector<absl::string_view>& reverse_utterance_list_to_encode,
- CluResponse* response);
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.cc
- index c16f5bc02b861..f893f0341c903 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.cc
- @@ -18,11 +18,11 @@ limitations under the License.
- #include <memory>
- #include <utility>
-
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/status/statusor.h" // from @com_google_absl
- -#include "absl/strings/str_cat.h" // from @com_google_absl
- -#include "absl/strings/str_join.h" // from @com_google_absl
- -#include "absl/strings/str_split.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/statusor.h" // from @com_google_absl
- +#include "absl/strings/str_cat.h" // from @com_google_absl
- +#include "absl/strings/str_join.h" // from @com_google_absl
- +#include "absl/strings/str_split.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow/lite/kernels/kernel_util.h"
- #include "tensorflow/lite/string_util.h"
- @@ -39,10 +39,14 @@ namespace tflite::task::text::clu {
- // tensors by concatenating the current utterance with history turns. It also
- // sets utterance_turn_id_seq for post-processing.
- absl::Status PopulateInputTextTensorForBERT(
- - const CluRequest& request, int token_id_tensor_idx,
- - int token_mask_tensor_idx, int token_type_id_tensor_idx,
- + const CluRequest& request,
- + int token_id_tensor_idx,
- + int token_mask_tensor_idx,
- + int token_type_id_tensor_idx,
- const tflite::support::text::tokenizer::BertTokenizer* tokenizer,
- - size_t max_seq_len, int max_history_turns, tflite::Interpreter* interpreter,
- + size_t max_seq_len,
- + int max_history_turns,
- + tflite::Interpreter* interpreter,
- Artifacts* artifacts) {
- size_t seq_len;
- int64_t* tokens_tensor =
- @@ -139,7 +143,8 @@ absl::Status AbstractModule::Init(tflite::Interpreter* interpreter,
- }
-
- absl::StatusOr<std::unique_ptr<AbstractModule>> UtteranceSeqModule::Create(
- - tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
- + tflite::Interpreter* interpreter,
- + const TensorIndexMap* tensor_index_map,
- const BertCluAnnotatorOptions* options,
- const tflite::support::text::tokenizer::BertTokenizer* tokenizer) {
- auto out = std::make_unique<UtteranceSeqModule>();
- @@ -187,7 +192,8 @@ AbstractModule::NamesAndConfidencesFromOutput(int names_tensor_idx,
- }
-
- absl::StatusOr<std::unique_ptr<AbstractModule>> DomainModule::Create(
- - tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
- + tflite::Interpreter* interpreter,
- + const TensorIndexMap* tensor_index_map,
- const BertCluAnnotatorOptions* options) {
- auto out = std::make_unique<DomainModule>();
- out->tensor_index_map_ = tensor_index_map;
- @@ -204,7 +210,8 @@ absl::Status DomainModule::Postprocess(Artifacts* artifacts,
- tensor_index_map_->domain_scores_idx));
- const auto& [names, confidences] = t_output;
- for (int i = 0; i < names.size(); ++i) {
- - if (confidences[i] < domain_threshold_) continue;
- + if (confidences[i] < domain_threshold_)
- + continue;
- auto domain = response->add_domains();
- // Conversion to string is needed due to portable_proto generated code
- const std::string names_i(names[i]);
- @@ -215,7 +222,8 @@ absl::Status DomainModule::Postprocess(Artifacts* artifacts,
- }
-
- absl::StatusOr<std::unique_ptr<AbstractModule>> IntentModule::Create(
- - tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
- + tflite::Interpreter* interpreter,
- + const TensorIndexMap* tensor_index_map,
- const BertCluAnnotatorOptions* options) {
- auto out = std::make_unique<IntentModule>();
- out->tensor_index_map_ = tensor_index_map;
- @@ -239,7 +247,8 @@ absl::Status IntentModule::Postprocess(Artifacts* artifacts,
- std::vector<absl::string_view> parts = absl::StrSplit(name.Name(), '=');
- if (parts.size() == 2) {
- // The name is like 'xxx=yyy'. It's a categorical slot.
- - if (confidences[i] < categorical_slot_threshold_) continue;
- + if (confidences[i] < categorical_slot_threshold_)
- + continue;
- auto new_categorical_slot = response->mutable_categorical_slots()->Add();
-
- const auto slot = std::string(parts[0]);
- @@ -251,7 +260,8 @@ absl::Status IntentModule::Postprocess(Artifacts* artifacts,
- new_categorical_slot_prediction->set_score(confidences[i]);
- } else {
- // It's an intent.
- - if (confidences[i] < intent_threshold_) continue;
- + if (confidences[i] < intent_threshold_)
- + continue;
- auto new_intent = response->mutable_intents()->Add();
- new_intent->set_display_name(name.Name());
- new_intent->set_score(confidences[i]);
- @@ -261,7 +271,8 @@ absl::Status IntentModule::Postprocess(Artifacts* artifacts,
- }
-
- absl::StatusOr<std::unique_ptr<AbstractModule>> SlotModule::Create(
- - tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
- + tflite::Interpreter* interpreter,
- + const TensorIndexMap* tensor_index_map,
- const BertCluAnnotatorOptions* options) {
- auto out = std::make_unique<SlotModule>();
- out->tensor_index_map_ = tensor_index_map;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.h
- index 5a9f183b8ca4e..eecd65fc495bf 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.h
- @@ -16,7 +16,7 @@ limitations under the License.
- #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_CLU_LIB_TFLITE_MODULES_H_
- #define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_CLU_LIB_TFLITE_MODULES_H_
-
- -#include "absl/status/statusor.h" // from @com_google_absl
- +#include "absl/status/statusor.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow/lite/interpreter.h"
- #include "tensorflow_lite_support/cc/task/text/proto/bert_clu_annotator_options_proto_inc.h"
- @@ -85,7 +85,8 @@ class AbstractModule {
- // output tensors.
- // The tensors are assumed to be of shape [1, max_seq_len]
- absl::StatusOr<NamesAndConfidences> NamesAndConfidencesFromOutput(
- - int names_tensor_idx, int scores_tensor_idx) const;
- + int names_tensor_idx,
- + int scores_tensor_idx) const;
-
- // TFLite interpreter
- Interpreter* interpreter_ = nullptr;
- @@ -98,7 +99,8 @@ class AbstractModule {
- class UtteranceSeqModule : public AbstractModule {
- public:
- static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
- - Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
- + Interpreter* interpreter,
- + const TensorIndexMap* tensor_index_map,
- const BertCluAnnotatorOptions* options,
- const tflite::support::text::tokenizer::BertTokenizer* tokenizer);
-
- @@ -116,7 +118,8 @@ class UtteranceSeqModule : public AbstractModule {
- class DomainModule : public AbstractModule {
- public:
- static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
- - Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
- + Interpreter* interpreter,
- + const TensorIndexMap* tensor_index_map,
- const BertCluAnnotatorOptions* options);
-
- absl::Status Postprocess(Artifacts* artifacts,
- @@ -130,7 +133,8 @@ class DomainModule : public AbstractModule {
- class IntentModule : public AbstractModule {
- public:
- static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
- - Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
- + Interpreter* interpreter,
- + const TensorIndexMap* tensor_index_map,
- const BertCluAnnotatorOptions* options);
-
- absl::Status Postprocess(Artifacts* artifacts,
- @@ -145,7 +149,8 @@ class IntentModule : public AbstractModule {
- class SlotModule : public AbstractModule {
- public:
- static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
- - Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
- + Interpreter* interpreter,
- + const TensorIndexMap* tensor_index_map,
- const BertCluAnnotatorOptions* options);
-
- absl::Status Postprocess(Artifacts* artifacts,
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_test_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_test_utils.cc
- index 543958ce93994..30d2bd7513909 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_test_utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_test_utils.cc
- @@ -24,7 +24,8 @@ namespace tflite::task::text::clu {
-
- template <>
- void PopulateTfLiteTensorValue<std::string>(
- - const std::initializer_list<std::string> values, TfLiteTensor* tensor) {
- + const std::initializer_list<std::string> values,
- + TfLiteTensor* tensor) {
- tflite::DynamicBuffer buf;
- for (const std::string& s : values) {
- buf.AddString(s.data(), s.length());
- @@ -38,13 +39,18 @@ size_t NumTotalFromShape(const std::initializer_list<int>& shape) {
- num_total = 1;
- else
- num_total = 0;
- - for (const int dim : shape) num_total *= dim;
- + for (const int dim : shape)
- + num_total *= dim;
- return num_total;
- }
-
- -TfLiteTensor* UniqueTfLiteTensor::get() { return tensor_; }
- +TfLiteTensor* UniqueTfLiteTensor::get() {
- + return tensor_;
- +}
-
- -UniqueTfLiteTensor::~UniqueTfLiteTensor() { TfLiteTensorFree(tensor_); }
- +UniqueTfLiteTensor::~UniqueTfLiteTensor() {
- + TfLiteTensorFree(tensor_);
- +}
-
- template <>
- TfLiteType TypeToTfLiteType<std::string>() {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_test_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_test_utils.h
- index 3a393c5223369..f19d2366fc092 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_test_utils.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/clu_lib/tflite_test_utils.h
- @@ -64,7 +64,8 @@ size_t NumTotalFromShape(const std::initializer_list<int>& shape);
-
- template <>
- void PopulateTfLiteTensorValue<std::string>(
- - const std::initializer_list<std::string> values, TfLiteTensor* tensor);
- + const std::initializer_list<std::string> values,
- + TfLiteTensor* tensor);
-
- template <typename T>
- TfLiteType TypeToTfLiteType() {
- @@ -84,7 +85,8 @@ void ReallocDynamicTensor(const std::initializer_list<int> shape,
- TfLiteIntArray* shape_arr = TfLiteIntArrayCreate(shape.size());
- int i = 0;
- const size_t num_total = NumTotalFromShape(shape);
- - for (const int dim : shape) shape_arr->data[i++] = dim;
- + for (const int dim : shape)
- + shape_arr->data[i++] = dim;
- tensor->dims = shape_arr;
- if (tensor->type != kTfLiteString) {
- TfLiteTensorRealloc(num_total * sizeof(T), tensor);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
- index 376ff58ec0b52..5a2966a70e1a2 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
- @@ -22,10 +22,10 @@ limitations under the License.
- #include <vector>
-
- #include "absl/algorithm/container.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/str_cat.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/str_cat.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- -#include "flatbuffers/flatbuffers.h" // from @flatbuffers
- +#include "flatbuffers/flatbuffers.h" // from @flatbuffers
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/core/api/op_resolver.h"
- #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
- @@ -127,7 +127,8 @@ StatusOr<std::vector<Category>> NLClassifier::ClassifyText(
- }
-
- absl::Status NLClassifier::Preprocess(
- - const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
- + const std::vector<TfLiteTensor*>& input_tensors,
- + const std::string& input) {
- return preprocessor_->Preprocess(input);
- }
-
- @@ -307,7 +308,8 @@ StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromOptions(
-
- StatusOr<std::unique_ptr<NLClassifier>>
- NLClassifier::CreateFromBufferAndOptions(
- - const char* model_buffer_data, size_t model_buffer_size,
- + const char* model_buffer_data,
- + size_t model_buffer_size,
- const NLClassifierOptions& options,
- std::unique_ptr<tflite::OpResolver> resolver) {
- std::unique_ptr<NLClassifier> nl_classifier;
- @@ -320,7 +322,8 @@ NLClassifier::CreateFromBufferAndOptions(
- }
-
- StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFileAndOptions(
- - const std::string& path_to_model, const NLClassifierOptions& options,
- + const std::string& path_to_model,
- + const NLClassifierOptions& options,
- std::unique_ptr<tflite::OpResolver> resolver) {
- std::unique_ptr<NLClassifier> nl_classifier;
- ASSIGN_OR_RETURN(nl_classifier,
- @@ -331,7 +334,8 @@ StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFileAndOptions(
- }
-
- StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFdAndOptions(
- - int fd, const NLClassifierOptions& options,
- + int fd,
- + const NLClassifierOptions& options,
- std::unique_ptr<tflite::OpResolver> resolver) {
- std::unique_ptr<NLClassifier> nl_classifier;
- ASSIGN_OR_RETURN(nl_classifier,
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
- index b7af66044b129..68ddc4b5312b7 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
- @@ -23,8 +23,8 @@ limitations under the License.
- #include <string>
- #include <vector>
-
- -#include "absl/base/macros.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/base/macros.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "flatbuffers/flatbuffers.h" // from @flatbuffers
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/core/api/op_resolver.h"
- @@ -109,7 +109,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
- ABSL_DEPRECATED("Prefer using `CreateFromOptions`")
- static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
- CreateFromBufferAndOptions(
- - const char* model_buffer_data, size_t model_buffer_size,
- + const char* model_buffer_data,
- + size_t model_buffer_size,
- const NLClassifierOptions& options = {},
- std::unique_ptr<tflite::OpResolver> resolver =
- absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
- @@ -118,7 +119,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
- ABSL_DEPRECATED("Prefer using `CreateFromOptions`")
- static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
- CreateFromFileAndOptions(
- - const std::string& path_to_model, const NLClassifierOptions& options = {},
- + const std::string& path_to_model,
- + const NLClassifierOptions& options = {},
- std::unique_ptr<tflite::OpResolver> resolver =
- absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
-
- @@ -126,7 +128,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
- ABSL_DEPRECATED("Prefer using `CreateFromOptions`")
- static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
- CreateFromFdAndOptions(
- - int fd, const NLClassifierOptions& options = {},
- + int fd,
- + const NLClassifierOptions& options = {},
- std::unique_ptr<tflite::OpResolver> resolver =
- absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
-
- @@ -182,7 +185,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
- const std::vector<TensorType*>& tensors,
- const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
- metadata_array,
- - const std::string& name, int index) {
- + const std::string& name,
- + int index) {
- int tensor_index = FindTensorIndex(tensors, metadata_array, name, index);
- return tensor_index >= 0 && tensor_index < tensors.size()
- ? tensors[tensor_index]
- @@ -197,7 +201,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
- const std::vector<TensorType*>& tensors,
- const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
- metadata_array,
- - const std::string& name, int default_index) {
- + const std::string& name,
- + int default_index) {
- if (metadata_array != nullptr && metadata_array->size() == tensors.size()) {
- for (size_t i = 0; i < metadata_array->size(); i++) {
- if (strcmp(name.data(), metadata_array->Get(i)->name()->c_str()) == 0) {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/text_searcher_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/text_searcher_options.proto
- index ebce50cbe5491..ed4c2db81dd01 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/text_searcher_options.proto
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/proto/text_searcher_options.proto
- @@ -21,7 +21,6 @@ import "tensorflow_lite_support/cc/task/core/proto/base_options.proto";
- import "tensorflow_lite_support/cc/task/processor/proto/embedding_options.proto";
- import "tensorflow_lite_support/cc/task/processor/proto/search_options.proto";
-
- -
- // Options for setting up an TextSearcher.
- // Next Id: 4.
- message TextSearcherOptions {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h
- index 4cde4329a716b..df21662a40e3a 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h
- @@ -45,9 +45,9 @@ struct QaAnswer {
- };
-
- // Interface for an Question-Answer API.
- -class QuestionAnswerer
- - : public core::BaseTaskApi<std::vector<QaAnswer>, const std::string&,
- - const std::string&> {
- +class QuestionAnswerer : public core::BaseTaskApi<std::vector<QaAnswer>,
- + const std::string&,
- + const std::string&> {
- public:
- explicit QuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine)
- : BaseTaskApi(std::move(engine)) {}
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.cc
- index 7363540797cf2..f7412224cae66 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.cc
- @@ -58,7 +58,8 @@ absl::Status SanityCheckOptions(const TextEmbedderOptions& options) {
-
- /* static */
- tflite::support::StatusOr<double> TextEmbedder::CosineSimilarity(
- - const FeatureVector& u, const FeatureVector& v) {
- + const FeatureVector& u,
- + const FeatureVector& v) {
- return processor::EmbeddingPostprocessor::CosineSimilarity(u, v);
- }
-
- @@ -170,7 +171,8 @@ tflite::support::StatusOr<EmbeddingResult> TextEmbedder::Embed(
- }
-
- absl::Status TextEmbedder::Preprocess(
- - const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
- + const std::vector<TfLiteTensor*>& input_tensors,
- + const std::string& input) {
- return preprocessor_->Preprocess(input);
- }
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.h
- index 3d20d558ca9a0..75597bc040468 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.h
- @@ -84,7 +84,8 @@ class TextEmbedder
- //
- // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
- static tflite::support::StatusOr<double> CosineSimilarity(
- - const processor::FeatureVector& u, const processor::FeatureVector& v);
- + const processor::FeatureVector& u,
- + const processor::FeatureVector& v);
-
- protected:
- // The options used to build this TextEmbedder.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_searcher.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_searcher.h
- index f9f680847ac5b..ca90bb6c0d141 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_searcher.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_searcher.h
- @@ -19,8 +19,8 @@ limitations under the License.
- #include <memory>
- #include <vector>
-
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/core/api/op_resolver.h"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc
- index 52b0041039acf..ba6af609c776b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc
- @@ -22,7 +22,7 @@ limitations under the License.
- #include <vector>
-
- #include "absl/container/flat_hash_map.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/core/base_task_api.h"
- #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
- @@ -169,7 +169,8 @@ StatusOr<FeatureVector> UniversalSentenceEncoderQA::EncodeQuery(
- }
-
- StatusOr<FeatureVector> UniversalSentenceEncoderQA::EncodeResponse(
- - absl::string_view response_text, absl::string_view response_context) {
- + absl::string_view response_text,
- + absl::string_view response_context) {
- if (response_text.empty() && response_context.empty()) {
- return Status(
- StatusCode::kInvalidArgument,
- @@ -190,7 +191,8 @@ StatusOr<float> UniversalSentenceEncoderQA::Similarity(const FeatureVector& a,
- }
-
- std::vector<size_t> UniversalSentenceEncoderQA::Top(
- - const RetrievalOutput& output, size_t k) {
- + const RetrievalOutput& output,
- + size_t k) {
- // Ensure k in [0, total_size).
- // If k == 0, it means that all outputs are ranked.
- if (k == 0) {
- @@ -214,7 +216,8 @@ std::vector<size_t> UniversalSentenceEncoderQA::Top(
- }
-
- Status UniversalSentenceEncoderQA::Preprocess(
- - const std::vector<TfLiteTensor*>& input_tensors, const QAInput& input) {
- + const std::vector<TfLiteTensor*>& input_tensors,
- + const QAInput& input) {
- RETURN_IF_ERROR(
- PopulateTensor(input.query_text, input_tensors[input_indices_[0]]));
- RETURN_IF_ERROR(
- @@ -235,7 +238,8 @@ StatusOr<QAOutput> UniversalSentenceEncoderQA::Postprocess(
- }
-
- internal::QAOutput UniversalSentenceEncoderQA::Run(
- - absl::string_view query_text, absl::string_view response_text,
- + absl::string_view query_text,
- + absl::string_view response_text,
- absl::string_view response_context) {
- QAInput input;
- input.query_text = std::string(query_text);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h
- index 3e83c7132c4e7..9b4a58676209c 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h
- @@ -20,8 +20,8 @@ limitations under the License.
- #include <vector>
-
- #include "absl/container/flat_hash_map.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/str_format.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/core/base_task_api.h"
- #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
- @@ -88,7 +88,8 @@ class UniversalSentenceEncoderQA
- // Encodes response from the text and/or context.
- // Returns an error, if both text and context are empty.
- tflite::support::StatusOr<FeatureVector> EncodeResponse(
- - absl::string_view response_text, absl::string_view response_context);
- + absl::string_view response_text,
- + absl::string_view response_context);
-
- // Calculates similarity between two encoded vectors (require same size).
- static tflite::support::StatusOr<float> Similarity(const FeatureVector& a,
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/utils/bert_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/utils/bert_utils.cc
- index 1c0a5b01b7789..04bfc2e4f95d7 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/utils/bert_utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/utils/bert_utils.cc
- @@ -17,7 +17,7 @@ limitations under the License.
-
- #include <algorithm>
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/common.h"
- #include "tensorflow_lite_support/cc/task/core/task_utils.h"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h
- index 76a03671b54af..d3557fc508c61 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h
- @@ -23,7 +23,7 @@ limitations under the License.
-
- #include "absl/memory/memory.h" // from @com_google_absl
- #include "absl/status/status.h" // from @com_google_absl
- -#include "absl/time/clock.h" // from @com_google_absl
- +#include "absl/time/clock.h" // from @com_google_absl
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow_lite_support/cc/common.h"
- #include "tensorflow_lite_support/cc/port/integral_types.h"
- @@ -45,11 +45,12 @@ namespace vision {
- // Base class providing common logic for vision models.
- template <class OutputType>
- class BaseVisionTaskApi
- - : public tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
- - const BoundingBox&> {
- + : public tflite::task::core::
- + BaseTaskApi<OutputType, const FrameBuffer&, const BoundingBox&> {
- public:
- explicit BaseVisionTaskApi(std::unique_ptr<core::TfLiteEngine> engine)
- - : tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
- + : tflite::task::core::BaseTaskApi<OutputType,
- + const FrameBuffer&,
- const BoundingBox&>(std::move(engine)) {
- }
- // BaseVisionTaskApi is neither copyable nor movable.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h
- index 47db0d121d43b..2e1aa6d652967 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h
- @@ -18,7 +18,7 @@ limitations under the License.
- #include <string>
- #include <vector>
-
- -#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
- index 1668447393e9e..2936f5acbb921 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
- @@ -22,12 +22,12 @@ limitations under the License.
- #include <utility>
- #include <vector>
-
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_cat.h" // from @com_google_absl
- -#include "absl/time/clock.h" // from @com_google_absl
- -#include "absl/time/time.h" // from @com_google_absl
- -#include "absl/types/optional.h" // from @com_google_absl
- +#include "absl/time/clock.h" // from @com_google_absl
- +#include "absl/time/time.h" // from @com_google_absl
- +#include "absl/types/optional.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/integral_types.h"
- #include "tensorflow_lite_support/cc/port/statusor.h"
-
- @@ -74,7 +74,16 @@ namespace vision {
- class FrameBuffer {
- public:
- // Colorspace formats.
- - enum class Format { kRGBA, kRGB, kNV12, kNV21, kYV12, kYV21, kGRAY, kUNKNOWN};
- + enum class Format {
- + kRGBA,
- + kRGB,
- + kNV12,
- + kNV21,
- + kYV12,
- + kYV21,
- + kGRAY,
- + kUNKNOWN
- + };
-
- // Stride information.
- struct Stride {
- @@ -166,7 +175,8 @@ class FrameBuffer {
- // buffers. In a streaming use case (e.g continuous camera stream), the
- // timestamp can be used as an ID to identify a frame.
- static std::unique_ptr<FrameBuffer> Create(const std::vector<Plane>& planes,
- - Dimension dimension, Format format,
- + Dimension dimension,
- + Format format,
- Orientation orientation,
- absl::Time timestamp) {
- return absl::make_unique<FrameBuffer>(planes, dimension, format,
- @@ -177,7 +187,8 @@ class FrameBuffer {
- // backing buffers. In a streaming use case (e.g continuous camera stream),
- // the timestamp can be used as an ID to identify a frame.
- static std::unique_ptr<FrameBuffer> Create(std::vector<Plane>&& planes,
- - Dimension dimension, Format format,
- + Dimension dimension,
- + Format format,
- Orientation orientation,
- absl::Time timestamp) {
- return absl::make_unique<FrameBuffer>(std::move(planes), dimension, format,
- @@ -189,7 +200,8 @@ class FrameBuffer {
- // more suitable for processing use case that does not need to re-identify
- // this buffer.
- static std::unique_ptr<FrameBuffer> Create(const std::vector<Plane>& planes,
- - Dimension dimension, Format format,
- + Dimension dimension,
- + Format format,
- Orientation orientation) {
- return absl::make_unique<FrameBuffer>(planes, dimension, format,
- orientation, absl::Now());
- @@ -200,7 +212,8 @@ class FrameBuffer {
- // method is more suitable for processing use case that does not need to
- // re-identify this buffer.
- static std::unique_ptr<FrameBuffer> Create(std::vector<Plane>&& planes,
- - Dimension dimension, Format format,
- + Dimension dimension,
- + Format format,
- Orientation orientation) {
- return absl::make_unique<FrameBuffer>(std::move(planes), dimension, format,
- orientation, absl::Now());
- @@ -217,8 +230,11 @@ class FrameBuffer {
- // The FrameBuffer does not take ownership of the backing buffer. The backing
- // buffer is read-only and the caller is responsible for maintaining the
- // backing buffer lifecycle for the lifetime of FrameBuffer.
- - FrameBuffer(const std::vector<Plane>& planes, Dimension dimension,
- - Format format, Orientation orientation, absl::Time timestamp)
- + FrameBuffer(const std::vector<Plane>& planes,
- + Dimension dimension,
- + Format format,
- + Orientation orientation,
- + absl::Time timestamp)
- : planes_(planes),
- dimension_(dimension),
- format_(format),
- @@ -230,8 +246,11 @@ class FrameBuffer {
- // The FrameBuffer does not take ownership of the backing buffer. The backing
- // buffer is read-only and the caller is responsible for maintaining the
- // backing buffer lifecycle for the lifetime of FrameBuffer.
- - FrameBuffer(std::vector<Plane>&& planes, Dimension dimension, Format format,
- - Orientation orientation, absl::Time timestamp)
- + FrameBuffer(std::vector<Plane>&& planes,
- + Dimension dimension,
- + Format format,
- + Orientation orientation,
- + absl::Time timestamp)
- : planes_(std::move(planes)),
- dimension_(dimension),
- format_(format),
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc
- index 9c82b63a10359..67fe07534b52a 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc
- @@ -16,7 +16,7 @@ limitations under the License.
- #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
-
- #include "absl/strings/str_format.h" // from @com_google_absl
- -#include "absl/strings/str_split.h" // from @com_google_absl
- +#include "absl/strings/str_split.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/common.h"
-
- namespace tflite {
- @@ -29,7 +29,8 @@ using ::tflite::support::StatusOr;
- using ::tflite::support::TfLiteSupportStatus;
-
- StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
- - absl::string_view labels_file, absl::string_view display_names_file) {
- + absl::string_view labels_file,
- + absl::string_view display_names_file) {
- if (labels_file.empty()) {
- return CreateStatusWithPayload(StatusCode::kInvalidArgument,
- "Expected non-empty labels file.",
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h
- index 0fb66f2639806..20c316ba4a992 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h
- @@ -20,8 +20,8 @@ limitations under the License.
-
- #include "absl/container/flat_hash_map.h" // from @com_google_absl
- #include "absl/container/flat_hash_set.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/string_view.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
-
- namespace tflite {
- @@ -49,7 +49,8 @@ struct LabelMapItem {
- // Returns an error e.g. if there's a mismatch between the number of labels and
- // display names.
- tflite::support::StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
- - absl::string_view labels_file, absl::string_view display_names_file);
- + absl::string_view labels_file,
- + absl::string_view display_names_file);
-
- // A class that represents a hierarchy of labels as specified in a label map.
- //
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc
- index aa1e7707dd99b..36ab3c3ca1903 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc
- @@ -16,9 +16,9 @@ limitations under the License.
- #include "tensorflow_lite_support/cc/task/vision/image_classifier.h"
-
- #include "absl/algorithm/container.h" // from @com_google_absl
- -#include "absl/strings/str_format.h" // from @com_google_absl
- +#include "absl/strings/str_format.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- -#include "flatbuffers/flatbuffers.h" // from @flatbuffers
- +#include "flatbuffers/flatbuffers.h" // from @flatbuffers
- #include "tensorflow_lite_support/cc/common.h"
- #include "tensorflow_lite_support/cc/port/integral_types.h"
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- @@ -146,7 +146,9 @@ absl::Status ImageClassifier::PreInit() {
- return absl::OkStatus();
- }
-
- -absl::Status ImageClassifier::PostInit() { return InitScoreCalibrations(); }
- +absl::Status ImageClassifier::PostInit() {
- + return InitScoreCalibrations();
- +}
-
- absl::Status ImageClassifier::CheckAndSetOutputs() {
- num_outputs_ = TfLiteEngine::OutputCount(GetTfLiteEngine()->interpreter());
- @@ -380,13 +382,15 @@ StatusOr<ClassificationResult> ImageClassifier::Classify(
- }
-
- StatusOr<ClassificationResult> ImageClassifier::Classify(
- - const FrameBuffer& frame_buffer, const BoundingBox& roi) {
- + const FrameBuffer& frame_buffer,
- + const BoundingBox& roi) {
- return InferWithFallback(frame_buffer, roi);
- }
-
- StatusOr<ClassificationResult> ImageClassifier::Postprocess(
- const std::vector<const TfLiteTensor*>& output_tensors,
- - const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) {
- + const FrameBuffer& /*frame_buffer*/,
- + const BoundingBox& /*roi*/) {
- if (output_tensors.size() != num_outputs_) {
- return CreateStatusWithPayload(
- StatusCode::kInternal,
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h
- index b2f595715e9da..eb0c13ec55c5b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h
- @@ -20,7 +20,7 @@ limitations under the License.
- #include <vector>
-
- #include "absl/container/flat_hash_set.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/core/api/op_resolver.h"
- #include "tensorflow/lite/core/shims/cc/kernels/register.h"
- @@ -109,7 +109,8 @@ class ImageClassifier : public BaseVisionTaskApi<ClassificationResult> {
- // region of interest is not clamped, so this method will return a non-ok
- // status if the region is out of these bounds.
- tflite::support::StatusOr<ClassificationResult> Classify(
- - const FrameBuffer& frame_buffer, const BoundingBox& roi);
- + const FrameBuffer& frame_buffer,
- + const BoundingBox& roi);
-
- protected:
- // The options used to build this ImageClassifier.
- @@ -123,7 +124,8 @@ class ImageClassifier : public BaseVisionTaskApi<ClassificationResult> {
- // results.
- tflite::support::StatusOr<ClassificationResult> Postprocess(
- const std::vector<const TfLiteTensor*>& output_tensors,
- - const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
- + const FrameBuffer& frame_buffer,
- + const BoundingBox& roi) override;
-
- // Performs sanity checks on the provided ImageClassifierOptions.
- static absl::Status SanityCheckOptions(const ImageClassifierOptions& options);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc
- index 0ce46fb9f9806..943a39b1f762e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc
- @@ -18,10 +18,10 @@ limitations under the License.
- #include <algorithm>
-
- #include "absl/container/node_hash_set.h" // from @com_google_absl
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/str_format.h" // from @com_google_absl
- -#include "absl/strings/string_view.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/str_format.h" // from @com_google_absl
- +#include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow_lite_support/cc/common.h"
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- @@ -51,7 +51,8 @@ CreatePostprocessor(core::TfLiteEngine* engine,
-
- /* static */
- tflite::support::StatusOr<double> ImageEmbedder::CosineSimilarity(
- - const FeatureVector& u, const FeatureVector& v) {
- + const FeatureVector& u,
- + const FeatureVector& v) {
- return processor::EmbeddingPostprocessor::CosineSimilarity(u, v);
- }
-
- @@ -118,13 +119,15 @@ tflite::support::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
- }
-
- tflite::support::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
- - const FrameBuffer& frame_buffer, const BoundingBox& roi) {
- + const FrameBuffer& frame_buffer,
- + const BoundingBox& roi) {
- return InferWithFallback(frame_buffer, roi);
- }
-
- tflite::support::StatusOr<EmbeddingResult> ImageEmbedder::Postprocess(
- const std::vector<const TfLiteTensor*>& output_tensors,
- - const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) {
- + const FrameBuffer& /*frame_buffer*/,
- + const BoundingBox& /*roi*/) {
- EmbeddingResult result;
- for (int i = 0; i < postprocessors_.size(); ++i) {
- RETURN_IF_ERROR(
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h
- index bc321c83d3774..93e2455eebd19 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h
- @@ -90,7 +90,8 @@ class ImageEmbedder
- // region of interest. Note that the region of interest is not clamped, so
- // this method will fail if the region is out of bounds of the input image.
- tflite::support::StatusOr<EmbeddingResult> Embed(
- - const FrameBuffer& frame_buffer, const BoundingBox& roi);
- + const FrameBuffer& frame_buffer,
- + const BoundingBox& roi);
-
- // Returns the Embedding output by the output_index'th layer. In (the most
- // common) case where a single embedding is produced, you can just call
- @@ -113,7 +114,8 @@ class ImageEmbedder
- //
- // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
- static tflite::support::StatusOr<double> CosineSimilarity(
- - const FeatureVector& u, const FeatureVector& v);
- + const FeatureVector& u,
- + const FeatureVector& v);
-
- protected:
- // The options used to build this ImageEmbedder.
- @@ -122,7 +124,8 @@ class ImageEmbedder
- // Post-processing to transform the raw model outputs into embedding results.
- tflite::support::StatusOr<EmbeddingResult> Postprocess(
- const std::vector<const TfLiteTensor*>& output_tensors,
- - const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
- + const FrameBuffer& frame_buffer,
- + const BoundingBox& roi) override;
-
- // Performs pre-initialization actions.
- virtual absl::Status PreInit();
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_searcher.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_searcher.cc
- index fb8bdf4f36446..4916290cb1473 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_searcher.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_searcher.cc
- @@ -19,8 +19,8 @@ limitations under the License.
- #include <utility>
- #include <vector>
-
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/core/api/op_resolver.h"
- @@ -110,7 +110,8 @@ StatusOr<absl::string_view> ImageSearcher::GetUserInfo() {
-
- StatusOr<SearchResult> ImageSearcher::Postprocess(
- const std::vector<const TfLiteTensor*>& /*output_tensors*/,
- - const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) {
- + const FrameBuffer& /*frame_buffer*/,
- + const BoundingBox& /*roi*/) {
- return postprocessor_->Postprocess();
- }
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_searcher.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_searcher.h
- index 4a510a615ab5b..6b43f8d7736d9 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_searcher.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_searcher.h
- @@ -19,7 +19,7 @@ limitations under the License.
- #include <memory>
- #include <vector>
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow/lite/core/api/op_resolver.h"
- #include "tensorflow/lite/core/shims/cc/kernels/register.h"
- @@ -93,7 +93,8 @@ class ImageSearcher
- // region of interest. Note that the region of interest is not clamped, so
- // this method will fail if the region is out of bounds of the input image.
- tflite::support::StatusOr<tflite::task::processor::SearchResult> Search(
- - const FrameBuffer& frame_buffer, const BoundingBox& roi);
- + const FrameBuffer& frame_buffer,
- + const BoundingBox& roi);
-
- // Provides access to the opaque user info stored in the index file (if any),
- // in raw binary form. Returns an empty string if the index doesn't contain
- @@ -108,7 +109,8 @@ class ImageSearcher
- // perform the nearest-neighbor search in the index.
- tflite::support::StatusOr<tflite::task::processor::SearchResult> Postprocess(
- const std::vector<const TfLiteTensor*>& output_tensors,
- - const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
- + const FrameBuffer& frame_buffer,
- + const BoundingBox& roi) override;
-
- // Initializes the ImageSearcher.
- absl::Status Init(std::unique_ptr<ImageSearcherOptions> options);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc
- index c9dad866f1a68..1cf9a54b91e0f 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc
- @@ -17,10 +17,10 @@ limitations under the License.
-
- #include <algorithm>
-
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/strings/str_format.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/strings/str_format.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- -#include "flatbuffers/flatbuffers.h" // from @flatbuffers
- +#include "flatbuffers/flatbuffers.h" // from @flatbuffers
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow_lite_support/cc/common.h"
- #include "tensorflow_lite_support/cc/port/integral_types.h"
- @@ -110,7 +110,8 @@ constexpr uint8 kColorMap[768] = {
-
- StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny(
- const ModelMetadataExtractor& metadata_extractor,
- - const TensorMetadata& tensor_metadata, absl::string_view locale) {
- + const TensorMetadata& tensor_metadata,
- + absl::string_view locale) {
- const std::string labels_filename =
- ModelMetadataExtractor::FindFirstAssociatedFileName(
- tensor_metadata, tflite::AssociatedFileType_TENSOR_AXIS_LABELS);
- @@ -332,7 +333,8 @@ StatusOr<SegmentationResult> ImageSegmenter::Segment(
-
- StatusOr<SegmentationResult> ImageSegmenter::Postprocess(
- const std::vector<const TfLiteTensor*>& output_tensors,
- - const FrameBuffer& frame_buffer, const BoundingBox& /*roi*/) {
- + const FrameBuffer& frame_buffer,
- + const BoundingBox& /*roi*/) {
- if (output_tensors.size() != 1) {
- return CreateStatusWithPayload(
- StatusCode::kInternal,
- @@ -432,7 +434,10 @@ StatusOr<SegmentationResult> ImageSegmenter::Postprocess(
- }
-
- StatusOr<float> ImageSegmenter::GetOutputConfidence(
- - const TfLiteTensor& output_tensor, int x, int y, int depth) {
- + const TfLiteTensor& output_tensor,
- + int x,
- + int y,
- + int depth) {
- int index = output_width_ * output_depth_ * y + output_depth_ * x + depth;
- if (has_uint8_outputs_) {
- ASSIGN_OR_RETURN(const uint8* data,
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h
- index 3f51f4962738e..e255110d9dc66 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h
- @@ -119,7 +119,8 @@ class ImageSegmenter : public BaseVisionTaskApi<SegmentationResult> {
- // results.
- tflite::support::StatusOr<SegmentationResult> Postprocess(
- const std::vector<const TfLiteTensor*>& output_tensors,
- - const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
- + const FrameBuffer& frame_buffer,
- + const BoundingBox& roi) override;
-
- // Performs sanity checks on the provided ImageSegmenterOptions.
- static absl::Status SanityCheckOptions(const ImageSegmenterOptions& options);
- @@ -148,7 +149,10 @@ class ImageSegmenter : public BaseVisionTaskApi<SegmentationResult> {
- // Returns the output confidence at coordinates {x, y, depth}, dequantizing
- // on-the-fly if needed (i.e. if `has_uint8_outputs_` is true).
- tflite::support::StatusOr<float> GetOutputConfidence(
- - const TfLiteTensor& output_tensor, int x, int y, int depth);
- + const TfLiteTensor& output_tensor,
- + int x,
- + int y,
- + int depth);
-
- // Prebuilt list of ColoredLabel attached to each Segmentation result. The
- // i-th item in this list corresponds to the i-th label map item.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc
- index 0a4d5f7553ee9..00775015515ac 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc
- @@ -20,8 +20,8 @@ limitations under the License.
- #include <vector>
-
- #include <glog/logging.h>
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow/lite/c/common.h"
- @@ -141,7 +141,8 @@ StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties(
-
- StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny(
- const ModelMetadataExtractor& metadata_extractor,
- - const TensorMetadata& tensor_metadata, absl::string_view locale) {
- + const TensorMetadata& tensor_metadata,
- + absl::string_view locale) {
- const std::string labels_filename =
- ModelMetadataExtractor::FindFirstAssociatedFileName(
- tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS);
- @@ -370,7 +371,9 @@ absl::Status ObjectDetector::PreInit() {
- return absl::OkStatus();
- }
-
- -absl::Status ObjectDetector::PostInit() { return InitScoreCalibrations(); }
- +absl::Status ObjectDetector::PostInit() {
- + return InitScoreCalibrations();
- +}
-
- StatusOr<SigmoidCalibrationParameters> BuildCalibrationParametersIfAny(
- const tflite::metadata::ModelMetadataExtractor& metadata_extractor,
- @@ -599,7 +602,8 @@ StatusOr<DetectionResult> ObjectDetector::Detect(
-
- StatusOr<DetectionResult> ObjectDetector::Postprocess(
- const std::vector<const TfLiteTensor*>& output_tensors,
- - const FrameBuffer& frame_buffer, const BoundingBox& /*roi*/) {
- + const FrameBuffer& frame_buffer,
- + const BoundingBox& /*roi*/) {
- // Most of the checks here should never happen, as outputs have been validated
- // at construction time. Checking nonetheless and returning internal errors if
- // something bad happens.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h
- index eaa6b5371ba52..c37fa8771081e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h
- @@ -19,7 +19,7 @@ limitations under the License.
- #include <memory>
-
- #include "absl/container/flat_hash_set.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "tensorflow/lite/core/api/op_resolver.h"
- #include "tensorflow/lite/core/shims/cc/kernels/register.h"
- #include "tensorflow_lite_support/cc/port/statusor.h"
- @@ -123,7 +123,8 @@ class ObjectDetector : public BaseVisionTaskApi<DetectionResult> {
- // Post-processing to transform the raw model outputs into detection results.
- tflite::support::StatusOr<DetectionResult> Postprocess(
- const std::vector<const TfLiteTensor*>& output_tensors,
- - const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
- + const FrameBuffer& frame_buffer,
- + const BoundingBox& roi) override;
-
- // Performs sanity checks on the provided ObjectDetectorOptions.
- static absl::Status SanityCheckOptions(const ObjectDetectorOptions& options);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_searcher_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_searcher_options.proto
- index 7501bb24d659d..5b5aaf1fa035c 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_searcher_options.proto
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/image_searcher_options.proto
- @@ -21,7 +21,6 @@ import "tensorflow_lite_support/cc/task/core/proto/base_options.proto";
- import "tensorflow_lite_support/cc/task/processor/proto/embedding_options.proto";
- import "tensorflow_lite_support/cc/task/processor/proto/search_options.proto";
-
- -
- // Options for setting up an ImageSearcher.
- // Next Id: 4.
- message ImageSearcherOptions {
- @@ -37,5 +36,4 @@ message ImageSearcherOptions {
- // Options specifying the index to search into and controlling the search
- // behavior.
- optional tflite.task.processor.SearchOptions search_options = 3;
- -
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc
- index 1854cf546d599..9a5b96160c033 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc
- @@ -18,7 +18,7 @@ limitations under the License.
- #include <string>
- #include <vector>
-
- -#include "absl/strings/str_cat.h" // from @com_google_absl
- +#include "absl/strings/str_cat.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/status_macros.h"
-
- @@ -36,8 +36,10 @@ constexpr int kGrayChannel = 1;
- // Creates a FrameBuffer from one plane raw NV21/NV12 buffer and passing
- // arguments.
- StatusOr<std::unique_ptr<FrameBuffer>> CreateFromOnePlaneNVRawBuffer(
- - const uint8* input, FrameBuffer::Dimension dimension,
- - FrameBuffer::Format format, FrameBuffer::Orientation orientation,
- + const uint8* input,
- + FrameBuffer::Dimension dimension,
- + FrameBuffer::Format format,
- + FrameBuffer::Orientation orientation,
- const absl::Time timestamp) {
- FrameBuffer::Plane input_plane = {/*buffer=*/input,
- /*stride=*/{dimension.width, kGrayChannel}};
- @@ -129,7 +131,8 @@ StatusOr<const uint8*> GetUvRawBuffer(const FrameBuffer& buffer) {
- }
-
- StatusOr<FrameBuffer::Dimension> GetUvPlaneDimension(
- - FrameBuffer::Dimension dimension, FrameBuffer::Format format) {
- + FrameBuffer::Dimension dimension,
- + FrameBuffer::Format format) {
- if (dimension.width <= 0 || dimension.height <= 0) {
- return absl::InvalidArgumentError(
- absl::StrFormat("Invalid input dimension: {%d, %d}.", dimension.width,
- @@ -176,7 +179,8 @@ absl::Status ValidateBufferFormat(const FrameBuffer& buffer) {
- case FrameBuffer::Format::kGRAY:
- case FrameBuffer::Format::kRGB:
- case FrameBuffer::Format::kRGBA:
- - if (buffer.plane_count() == 1) return absl::OkStatus();
- + if (buffer.plane_count() == 1)
- + return absl::OkStatus();
- return absl::InvalidArgumentError(
- "Plane count must be 1 for grayscale and RGB[a] buffers.");
- case FrameBuffer::Format::kNV21:
- @@ -252,8 +256,11 @@ absl::Status ValidateRotateBufferInputs(const FrameBuffer& buffer,
- }
-
- absl::Status ValidateCropBufferInputs(const FrameBuffer& buffer,
- - const FrameBuffer& output_buffer, int x0,
- - int y0, int x1, int y1) {
- + const FrameBuffer& output_buffer,
- + int x0,
- + int y0,
- + int x1,
- + int y1) {
- if (!AreBufferFormatsCompatible(buffer, output_buffer)) {
- return absl::InvalidArgumentError(
- "Input and output buffer formats must match.");
- @@ -309,8 +316,10 @@ absl::Status ValidateConvertFormats(FrameBuffer::Format from_format,
-
- // Creates a FrameBuffer from raw RGBA buffer and passing arguments.
- std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer(
- - const uint8* input, FrameBuffer::Dimension dimension,
- - FrameBuffer::Orientation orientation, const absl::Time timestamp,
- + const uint8* input,
- + FrameBuffer::Dimension dimension,
- + FrameBuffer::Orientation orientation,
- + const absl::Time timestamp,
- FrameBuffer::Stride stride) {
- if (stride == kDefaultStride) {
- stride.row_stride_bytes = dimension.width * kRgbaChannels;
- @@ -325,8 +334,10 @@ std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer(
-
- // Creates a FrameBuffer from raw RGB buffer and passing arguments.
- std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer(
- - const uint8* input, FrameBuffer::Dimension dimension,
- - FrameBuffer::Orientation orientation, const absl::Time timestamp,
- + const uint8* input,
- + FrameBuffer::Dimension dimension,
- + FrameBuffer::Orientation orientation,
- + const absl::Time timestamp,
- FrameBuffer::Stride stride) {
- if (stride == kDefaultStride) {
- stride.row_stride_bytes = dimension.width * kRgbChannels;
- @@ -340,8 +351,10 @@ std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer(
-
- // Creates a FrameBuffer from raw grayscale buffer and passing arguments.
- std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer(
- - const uint8* input, FrameBuffer::Dimension dimension,
- - FrameBuffer::Orientation orientation, const absl::Time timestamp,
- + const uint8* input,
- + FrameBuffer::Dimension dimension,
- + FrameBuffer::Orientation orientation,
- + const absl::Time timestamp,
- FrameBuffer::Stride stride) {
- if (stride == kDefaultStride) {
- stride.row_stride_bytes = dimension.width * kGrayChannel;
- @@ -356,10 +369,16 @@ std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer(
-
- // Creates a FrameBuffer from raw YUV buffer and passing arguments.
- StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer(
- - const uint8* y_plane, const uint8* u_plane, const uint8* v_plane,
- - FrameBuffer::Format format, FrameBuffer::Dimension dimension,
- - int row_stride_y, int row_stride_uv, int pixel_stride_uv,
- - FrameBuffer::Orientation orientation, const absl::Time timestamp) {
- + const uint8* y_plane,
- + const uint8* u_plane,
- + const uint8* v_plane,
- + FrameBuffer::Format format,
- + FrameBuffer::Dimension dimension,
- + int row_stride_y,
- + int row_stride_uv,
- + int pixel_stride_uv,
- + FrameBuffer::Orientation orientation,
- + const absl::Time timestamp) {
- const int pixel_stride_y = 1;
- std::vector<FrameBuffer::Plane> planes;
- if (format == FrameBuffer::Format::kNV21 ||
- @@ -380,9 +399,11 @@ StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer(
- }
-
- StatusOr<std::unique_ptr<FrameBuffer>> CreateFromRawBuffer(
- - const uint8* buffer, FrameBuffer::Dimension dimension,
- + const uint8* buffer,
- + FrameBuffer::Dimension dimension,
- const FrameBuffer::Format target_format,
- - FrameBuffer::Orientation orientation, absl::Time timestamp) {
- + FrameBuffer::Orientation orientation,
- + absl::Time timestamp) {
- switch (target_format) {
- case FrameBuffer::Format::kNV12:
- return CreateFromOnePlaneNVRawBuffer(buffer, dimension, target_format,
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h
- index 470e76b9037a1..7ebf69fadc3de 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h
- @@ -18,8 +18,8 @@ limitations under the License.
- #include <memory>
-
- #include "absl/status/status.h" // from @com_google_absl
- -#include "absl/time/clock.h" // from @com_google_absl
- -#include "absl/time/time.h" // from @com_google_absl
- +#include "absl/time/clock.h" // from @com_google_absl
- +#include "absl/time/time.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/integral_types.h"
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
- @@ -58,7 +58,8 @@ tflite::support::StatusOr<const uint8*> GetUvRawBuffer(
- // supported formats. This method assums the UV plane share the same dimension,
- // especially for the YV12 / YV21 formats.
- tflite::support::StatusOr<FrameBuffer::Dimension> GetUvPlaneDimension(
- - FrameBuffer::Dimension dimension, FrameBuffer::Format format);
- + FrameBuffer::Dimension dimension,
- + FrameBuffer::Format format);
-
- // Returns crop dimension based on crop start and end points.
- FrameBuffer::Dimension GetCropDimension(int x0, int x1, int y0, int y1);
- @@ -92,8 +93,11 @@ absl::Status ValidateRotateBufferInputs(const FrameBuffer& buffer,
- // (x0, y0) represents the top-left point of the buffer.
- // (x1, y1) represents the bottom-right point of the buffer.
- absl::Status ValidateCropBufferInputs(const FrameBuffer& buffer,
- - const FrameBuffer& output_buffer, int x0,
- - int y0, int x1, int y1);
- + const FrameBuffer& output_buffer,
- + int x0,
- + int y0,
- + int x1,
- + int y1);
-
- // Validates the given inputs for flipping `buffer` horizontally or vertically.
- absl::Status ValidateFlipBufferInputs(const FrameBuffer& buffer,
- @@ -110,36 +114,45 @@ absl::Status ValidateConvertFormats(FrameBuffer::Format from_format,
-
- // Creates a FrameBuffer from raw RGBA buffer and passing arguments.
- std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer(
- - const uint8* input, FrameBuffer::Dimension dimension,
- + const uint8* input,
- + FrameBuffer::Dimension dimension,
- FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
- absl::Time timestamp = absl::Now(),
- FrameBuffer::Stride stride = kDefaultStride);
-
- // Creates a FrameBuffer from raw RGB buffer and passing arguments.
- std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer(
- - const uint8* input, FrameBuffer::Dimension dimension,
- + const uint8* input,
- + FrameBuffer::Dimension dimension,
- FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
- absl::Time timestamp = absl::Now(),
- FrameBuffer::Stride stride = kDefaultStride);
-
- // Creates a FrameBuffer from raw grayscale buffer and passing arguments.
- std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer(
- - const uint8* input, FrameBuffer::Dimension dimension,
- + const uint8* input,
- + FrameBuffer::Dimension dimension,
- FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
- absl::Time timestamp = absl::Now(),
- FrameBuffer::Stride stride = kDefaultStride);
-
- // Creates a FrameBuffer from raw YUV buffer and passing arguments.
- tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer(
- - const uint8* y_plane, const uint8* u_plane, const uint8* v_plane,
- - FrameBuffer::Format format, FrameBuffer::Dimension dimension,
- - int row_stride_y, int row_stride_uv, int pixel_stride_uv,
- + const uint8* y_plane,
- + const uint8* u_plane,
- + const uint8* v_plane,
- + FrameBuffer::Format format,
- + FrameBuffer::Dimension dimension,
- + int row_stride_y,
- + int row_stride_uv,
- + int pixel_stride_uv,
- FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
- absl::Time timestamp = absl::Now());
-
- // Creates an instance of FrameBuffer from raw buffer and passing arguments.
- tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> CreateFromRawBuffer(
- - const uint8* buffer, FrameBuffer::Dimension dimension,
- + const uint8* buffer,
- + FrameBuffer::Dimension dimension,
- FrameBuffer::Format target_format,
- FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
- absl::Time timestamp = absl::Now());
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc
- index 4d767fc3e48b2..4728c30cb60dc 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc
- @@ -22,8 +22,8 @@ limitations under the License.
- #include <utility>
- #include <vector>
-
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow/lite/kernels/internal/compatibility.h"
- #include "tensorflow/lite/kernels/op_macros.h"
- @@ -91,7 +91,8 @@ static int GetOrientationIndex(FrameBuffer::Orientation orientation) {
- // The new box origin is (x:box.origin_y, y:width - (box.origin_x + box.width).
- // The new box dimension is (w: box.height, h: box.width).
- //
- -static BoundingBox RotateBoundingBox(const BoundingBox& box, int angle,
- +static BoundingBox RotateBoundingBox(const BoundingBox& box,
- + int angle,
- FrameBuffer::Dimension frame_dimension) {
- int rx = box.origin_x(), ry = box.origin_y(), rw = box.width(),
- rh = box.height();
- @@ -130,9 +131,12 @@ static BoundingBox RotateBoundingBox(const BoundingBox& box, int angle,
- // in counterclockwise degree in one of the values [0, 90, 180, 270].
- //
- // See `RotateBoundingBox` above for more details.
- -static void RotateCoordinates(int from_x, int from_y, int angle,
- +static void RotateCoordinates(int from_x,
- + int from_y,
- + int angle,
- const FrameBuffer::Dimension& frame_dimension,
- - int* to_x, int* to_y) {
- + int* to_x,
- + int* to_y) {
- switch (angle) {
- case 0:
- *to_x = from_x;
- @@ -199,7 +203,10 @@ BoundingBox OrientBoundingBox(const BoundingBox& from_box,
- }
-
- BoundingBox OrientAndDenormalizeBoundingBox(
- - float from_left, float from_top, float from_right, float from_bottom,
- + float from_left,
- + float from_top,
- + float from_right,
- + float from_bottom,
- FrameBuffer::Orientation from_orientation,
- FrameBuffer::Orientation to_orientation,
- FrameBuffer::Dimension from_dimension) {
- @@ -214,10 +221,12 @@ BoundingBox OrientAndDenormalizeBoundingBox(
- return to_box;
- }
-
- -void OrientCoordinates(int from_x, int from_y,
- +void OrientCoordinates(int from_x,
- + int from_y,
- FrameBuffer::Orientation from_orientation,
- FrameBuffer::Orientation to_orientation,
- - FrameBuffer::Dimension from_dimension, int* to_x,
- + FrameBuffer::Dimension from_dimension,
- + int* to_x,
- int* to_y) {
- *to_x = from_x;
- *to_y = from_y;
- @@ -298,15 +307,19 @@ bool RequireDimensionSwap(FrameBuffer::Orientation from_orientation,
- return params.rotation_angle_deg == 90 || params.rotation_angle_deg == 270;
- }
-
- -absl::Status FrameBufferUtils::Crop(const FrameBuffer& buffer, int x0, int y0,
- - int x1, int y1,
- +absl::Status FrameBufferUtils::Crop(const FrameBuffer& buffer,
- + int x0,
- + int y0,
- + int x1,
- + int y1,
- FrameBuffer* output_buffer) {
- TFLITE_DCHECK(utils_ != nullptr);
- return utils_->Crop(buffer, x0, y0, x1, y1, output_buffer);
- }
-
- FrameBuffer::Dimension FrameBufferUtils::GetSize(
- - const FrameBuffer& buffer, const FrameBufferOperation& operation) {
- + const FrameBuffer& buffer,
- + const FrameBufferOperation& operation) {
- FrameBuffer::Dimension dimension = buffer.dimension();
- if (absl::holds_alternative<OrientOperation>(operation)) {
- OrientParams params =
- @@ -327,7 +340,8 @@ FrameBuffer::Dimension FrameBufferUtils::GetSize(
- }
-
- std::vector<FrameBuffer::Plane> FrameBufferUtils::GetPlanes(
- - const uint8* buffer, FrameBuffer::Dimension dimension,
- + const uint8* buffer,
- + FrameBuffer::Dimension dimension,
- FrameBuffer::Format format) {
- std::vector<FrameBuffer::Plane> planes;
- switch (format) {
- @@ -378,7 +392,8 @@ std::vector<FrameBuffer::Plane> FrameBufferUtils::GetPlanes(
- }
-
- FrameBuffer::Orientation FrameBufferUtils::GetOrientation(
- - const FrameBuffer& buffer, const FrameBufferOperation& operation) {
- + const FrameBuffer& buffer,
- + const FrameBufferOperation& operation) {
- if (absl::holds_alternative<OrientOperation>(operation)) {
- return absl::get<OrientOperation>(operation).to_orientation;
- }
- @@ -386,7 +401,8 @@ FrameBuffer::Orientation FrameBufferUtils::GetOrientation(
- }
-
- FrameBuffer::Format FrameBufferUtils::GetFormat(
- - const FrameBuffer& buffer, const FrameBufferOperation& operation) {
- + const FrameBuffer& buffer,
- + const FrameBufferOperation& operation) {
- if (absl::holds_alternative<ConvertOperation>(operation)) {
- return absl::get<ConvertOperation>(operation).to_format;
- }
- @@ -578,8 +594,10 @@ absl::Status FrameBufferUtils::Execute(
- }
-
- absl::Status FrameBufferUtils::Preprocess(
- - const FrameBuffer& buffer, absl::optional<BoundingBox> bounding_box,
- - FrameBuffer* output_buffer, bool uniform_resizing) {
- + const FrameBuffer& buffer,
- + absl::optional<BoundingBox> bounding_box,
- + FrameBuffer* output_buffer,
- + bool uniform_resizing) {
- std::vector<FrameBufferOperation> frame_buffer_operations;
- // Handle cropping and resizing.
- bool needs_dimension_swap =
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h
- index 59e80e5765bb0..48549461159cb 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h
- @@ -19,9 +19,9 @@ limitations under the License.
- #include <memory>
- #include <vector>
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/types/optional.h" // from @com_google_absl
- -#include "absl/types/variant.h" // from @com_google_absl
- +#include "absl/types/variant.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/integral_types.h"
- #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
- #include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
- @@ -45,7 +45,10 @@ BoundingBox OrientBoundingBox(const BoundingBox& from_box,
-
- // Same as OrientBoundingBox but from normalized coordinates.
- BoundingBox OrientAndDenormalizeBoundingBox(
- - float from_left, float from_top, float from_right, float from_bottom,
- + float from_left,
- + float from_top,
- + float from_right,
- + float from_bottom,
- FrameBuffer::Orientation from_orientation,
- FrameBuffer::Orientation to_orientation,
- FrameBuffer::Dimension from_dimension);
- @@ -53,10 +56,12 @@ BoundingBox OrientAndDenormalizeBoundingBox(
- // Rotates `(from_x, from_y)` coordinates from an image of dimension
- // `from_dimension` and orientation `from_orientation` into `(to_x, to_y)`
- // coordinates with orientation `to_orientation`.
- -void OrientCoordinates(int from_x, int from_y,
- +void OrientCoordinates(int from_x,
- + int from_y,
- FrameBuffer::Orientation from_orientation,
- FrameBuffer::Orientation to_orientation,
- - FrameBuffer::Dimension from_dimension, int* to_x,
- + FrameBuffer::Dimension from_dimension,
- + int* to_x,
- int* to_y);
-
- // Returns whether the conversion from from_orientation to to_orientation
- @@ -92,7 +97,8 @@ OrientParams GetOrientParams(FrameBuffer::Orientation from_orientation,
- // To perform just cropping, the `crop_width` and `crop_height` should be the
- // same as `resize_width` `and resize_height`.
- struct CropResizeOperation {
- - CropResizeOperation(int crop_origin_x, int crop_origin_y,
- + CropResizeOperation(int crop_origin_x,
- + int crop_origin_y,
- FrameBuffer::Dimension crop_dimension,
- FrameBuffer::Dimension resize_dimension)
- : crop_origin_x(crop_origin_x),
- @@ -124,7 +130,8 @@ struct CropResizeOperation {
- // The resized region is aligned to the upper left pixel of the output buffer.
- // The unfilled area of the output buffer remains untouched.
- struct UniformCropResizeOperation {
- - UniformCropResizeOperation(int crop_origin_x, int crop_origin_y,
- + UniformCropResizeOperation(int crop_origin_x,
- + int crop_origin_y,
- FrameBuffer::Dimension crop_dimension,
- FrameBuffer::Dimension output_dimension)
- : crop_origin_x(crop_origin_x),
- @@ -154,9 +161,10 @@ struct OrientOperation {
-
- // A variant of the supported operations on FrameBuffers. Alias for user
- // convenience.
- -using FrameBufferOperation =
- - absl::variant<CropResizeOperation, ConvertOperation, OrientOperation,
- - UniformCropResizeOperation>;
- +using FrameBufferOperation = absl::variant<CropResizeOperation,
- + ConvertOperation,
- + OrientOperation,
- + UniformCropResizeOperation>;
-
- // Image processing utility. This utility provides both basic image buffer
- // manipulations (e.g. rotation, format conversion, resizing, etc) as well as
- @@ -212,7 +220,11 @@ class FrameBufferUtils {
- // should be big enough to store the operation result. If the `output_buffer`
- // size dimension does not match with crop dimension, then a resize is
- // automatically performed.
- - absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
- + absl::Status Crop(const FrameBuffer& buffer,
- + int x0,
- + int y0,
- + int x1,
- + int y1,
- FrameBuffer* output_buffer);
-
- // Performs resizing operation.
- @@ -229,7 +241,8 @@ class FrameBufferUtils {
- //
- // The output_buffer should have metadata populated and its backing buffer
- // should be big enough to store the operation result.
- - absl::Status Rotate(const FrameBuffer& buffer, RotationDegree rotation,
- + absl::Status Rotate(const FrameBuffer& buffer,
- + RotationDegree rotation,
- FrameBuffer* output_buffer);
-
- // Performs horizontal flip operation.
- @@ -305,7 +318,8 @@ class FrameBufferUtils {
-
- // Returns the new FrameBuffer orientation after command is processed.
- FrameBuffer::Orientation GetOrientation(
- - const FrameBuffer& buffer, const FrameBufferOperation& operation);
- + const FrameBuffer& buffer,
- + const FrameBufferOperation& operation);
-
- // Returns the new FrameBuffer format after command is processed.
- FrameBuffer::Format GetFormat(const FrameBuffer& buffer,
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h
- index ec0c3119ea4e8..59da2206bb06f 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h
- @@ -37,8 +37,12 @@ class FrameBufferUtilsInterface {
- //
- // The `output_buffer` should have metadata populated and its backing buffer
- // should be big enough to store the operation result.
- - virtual absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1,
- - int y1, FrameBuffer* output_buffer) = 0;
- + virtual absl::Status Crop(const FrameBuffer& buffer,
- + int x0,
- + int y0,
- + int x1,
- + int y1,
- + FrameBuffer* output_buffer) = 0;
-
- // Resizes `buffer` to the size of the given `output_buffer`.
- //
- @@ -57,7 +61,8 @@ class FrameBufferUtilsInterface {
- //
- // The `output_buffer` should have metadata populated and its backing buffer
- // should be big enough to store the operation result.
- - virtual absl::Status Rotate(const FrameBuffer& buffer, int angle_deg,
- + virtual absl::Status Rotate(const FrameBuffer& buffer,
- + int angle_deg,
- FrameBuffer* output_buffer) = 0;
-
- // Flips `buffer` horizontally.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.cc
- index 3f8bc7b43f4f1..d5b277ad33b89 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.cc
- @@ -23,11 +23,11 @@ limitations under the License.
- #define STB_IMAGE_IMPLEMENTATION
- #define STB_IMAGE_WRITE_IMPLEMENTATION
-
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/match.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/match.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- -#include "stb_image.h" // from @stblib
- -#include "stb_image_write.h" // from @stblib
- +#include "stb_image.h" // from @stblib
- +#include "stb_image_write.h" // from @stblib
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
- @@ -88,7 +88,9 @@ absl::Status EncodeImageToPngFile(const ImageData& image_data,
- return absl::OkStatus();
- }
-
- -void ImageDataFree(ImageData* image) { stbi_image_free(image->pixel_data); }
- +void ImageDataFree(ImageData* image) {
- + stbi_image_free(image->pixel_data);
- +}
-
- tflite::support::StatusOr<std::unique_ptr<FrameBuffer>>
- CreateFrameBufferFromImageData(const ImageData& image) {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.h
- index 6ba5c2d6490ab..7de32ee9c0f53 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.h
- @@ -15,7 +15,7 @@ limitations under the License.
- #ifndef TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_
- #define TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/integral_types.h"
- #include "tensorflow_lite_support/cc/port/statusor.h"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
- index e0dd8a99c64c0..a0ee2dab96b6a 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
- @@ -20,11 +20,11 @@ limitations under the License.
- #include <memory>
- #include <string>
-
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/str_cat.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/str_cat.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- -#include "libyuv.h" // from @libyuv
- -#include "libyuv/convert_argb.h" // from @libyuv
- +#include "libyuv.h" // from @libyuv
- +#include "libyuv/convert_argb.h" // from @libyuv
- #include "tensorflow_lite_support/cc/common.h"
- #include "tensorflow_lite_support/cc/port/integral_types.h"
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- @@ -384,7 +384,8 @@ absl::Status ResizeNv(const FrameBuffer& buffer, FrameBuffer* output_buffer) {
-
- // Converts `buffer` to libyuv ARGB format and stores the conversion result
- // in `dest_argb`.
- -absl::Status ConvertRgbToArgb(const FrameBuffer& buffer, uint8* dest_argb,
- +absl::Status ConvertRgbToArgb(const FrameBuffer& buffer,
- + uint8* dest_argb,
- int dest_stride_argb) {
- RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
- if (buffer.format() != FrameBuffer::Format::kRGB) {
- @@ -421,7 +422,8 @@ absl::Status ConvertRgbToArgb(const FrameBuffer& buffer, uint8* dest_argb,
-
- // Converts `src_argb` in libyuv ARGB format to FrameBuffer::kRGB format and
- // stores the conversion result in `output_buffer`.
- -absl::Status ConvertArgbToRgb(uint8* src_argb, int src_stride_argb,
- +absl::Status ConvertArgbToRgb(uint8* src_argb,
- + int src_stride_argb,
- FrameBuffer* output_buffer) {
- RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
- if (output_buffer->format() != FrameBuffer::Format::kRGB) {
- @@ -457,7 +459,8 @@ absl::Status ConvertArgbToRgb(uint8* src_argb, int src_stride_argb,
-
- // Converts `buffer` in FrameBuffer::kRGBA format to libyuv ARGB (BGRA in
- // memory) format and stores the conversion result in `dest_argb`.
- -absl::Status ConvertRgbaToArgb(const FrameBuffer& buffer, uint8* dest_argb,
- +absl::Status ConvertRgbaToArgb(const FrameBuffer& buffer,
- + uint8* dest_argb,
- int dest_stride_argb) {
- RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
- if (buffer.format() != FrameBuffer::Format::kRGBA) {
- @@ -689,7 +692,8 @@ libyuv::RotationMode GetLibyuvRotationMode(int angle_deg) {
- }
- }
-
- -absl::Status RotateRgba(const FrameBuffer& buffer, int angle_deg,
- +absl::Status RotateRgba(const FrameBuffer& buffer,
- + int angle_deg,
- FrameBuffer* output_buffer) {
- if (buffer.plane_count() > 1) {
- return CreateStatusWithPayload(
- @@ -713,7 +717,8 @@ absl::Status RotateRgba(const FrameBuffer& buffer, int angle_deg,
- return absl::OkStatus();
- }
-
- -absl::Status RotateRgb(const FrameBuffer& buffer, int angle_deg,
- +absl::Status RotateRgb(const FrameBuffer& buffer,
- + int angle_deg,
- FrameBuffer* output_buffer) {
- // libyuv does not support rotate kRGB (RGB24) foramat. In this method, the
- // implementation converts kRGB format to ARGB and use ARGB buffer for
- @@ -746,7 +751,8 @@ absl::Status RotateRgb(const FrameBuffer& buffer, int angle_deg,
- output_buffer);
- }
-
- -absl::Status RotateGray(const FrameBuffer& buffer, int angle_deg,
- +absl::Status RotateGray(const FrameBuffer& buffer,
- + int angle_deg,
- FrameBuffer* output_buffer) {
- if (buffer.plane_count() > 1) {
- return CreateStatusWithPayload(
- @@ -769,7 +775,8 @@ absl::Status RotateGray(const FrameBuffer& buffer, int angle_deg,
- }
-
- // Rotates YV12/YV21 frame buffer.
- -absl::Status RotateYv(const FrameBuffer& buffer, int angle_deg,
- +absl::Status RotateYv(const FrameBuffer& buffer,
- + int angle_deg,
- FrameBuffer* output_buffer) {
- ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
- FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
- @@ -794,7 +801,8 @@ absl::Status RotateYv(const FrameBuffer& buffer, int angle_deg,
- // Rotates NV12/NV21 frame buffer.
- // TODO(b/152097364): Refactor NV12/NV21 rotation after libyuv explicitly
- // support that.
- -absl::Status RotateNv(const FrameBuffer& buffer, int angle_deg,
- +absl::Status RotateNv(const FrameBuffer& buffer,
- + int angle_deg,
- FrameBuffer* output_buffer) {
- if (buffer.format() != FrameBuffer::Format::kNV12 &&
- buffer.format() != FrameBuffer::Format::kNV21) {
- @@ -884,8 +892,12 @@ absl::Status FlipPlaneVertically(const FrameBuffer& buffer,
- }
-
- // This method only supports kGRAY, kRGBA, and kRGB formats.
- -absl::Status CropPlane(const FrameBuffer& buffer, int x0, int y0, int x1,
- - int y1, FrameBuffer* output_buffer) {
- +absl::Status CropPlane(const FrameBuffer& buffer,
- + int x0,
- + int y0,
- + int x1,
- + int y1,
- + FrameBuffer* output_buffer) {
- if (buffer.plane_count() > 1) {
- return CreateStatusWithPayload(
- StatusCode::kInternal,
- @@ -912,7 +924,11 @@ absl::Status CropPlane(const FrameBuffer& buffer, int x0, int y0, int x1,
-
- // Crops NV12/NV21 FrameBuffer to the subregion defined by the top left pixel
- // position (x0, y0) and the bottom right pixel position (x1, y1).
- -absl::Status CropNv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
- +absl::Status CropNv(const FrameBuffer& buffer,
- + int x0,
- + int y0,
- + int x1,
- + int y1,
- FrameBuffer* output_buffer) {
- ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
- FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
- @@ -944,7 +960,11 @@ absl::Status CropNv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
-
- // Crops YV12/YV21 FrameBuffer to the subregion defined by the top left pixel
- // position (x0, y0) and the bottom right pixel position (x1, y1).
- -absl::Status CropYv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
- +absl::Status CropYv(const FrameBuffer& buffer,
- + int x0,
- + int y0,
- + int x1,
- + int y1,
- FrameBuffer* output_buffer) {
- ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
- FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
- @@ -979,8 +999,12 @@ absl::Status CropYv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
- return absl::OkStatus();
- }
-
- -absl::Status CropResizeYuv(const FrameBuffer& buffer, int x0, int y0, int x1,
- - int y1, FrameBuffer* output_buffer) {
- +absl::Status CropResizeYuv(const FrameBuffer& buffer,
- + int x0,
- + int y0,
- + int x1,
- + int y1,
- + FrameBuffer* output_buffer) {
- FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1);
- if (crop_dimension == output_buffer->dimension()) {
- switch (buffer.format()) {
- @@ -1308,8 +1332,12 @@ absl::Status ResizeGray(const FrameBuffer& buffer, FrameBuffer* output_buffer) {
- }
-
- // This method only supports kGRAY, kRGBA, and kRGB formats.
- -absl::Status CropResize(const FrameBuffer& buffer, int x0, int y0, int x1,
- - int y1, FrameBuffer* output_buffer) {
- +absl::Status CropResize(const FrameBuffer& buffer,
- + int x0,
- + int y0,
- + int x1,
- + int y1,
- + FrameBuffer* output_buffer) {
- FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1);
- if (crop_dimension == output_buffer->dimension()) {
- return CropPlane(buffer, x0, y0, x1, y1, output_buffer);
- @@ -1343,8 +1371,11 @@ absl::Status CropResize(const FrameBuffer& buffer, int x0, int y0, int x1,
-
- } // namespace
-
- -absl::Status LibyuvFrameBufferUtils::Crop(const FrameBuffer& buffer, int x0,
- - int y0, int x1, int y1,
- +absl::Status LibyuvFrameBufferUtils::Crop(const FrameBuffer& buffer,
- + int x0,
- + int y0,
- + int x1,
- + int y1,
- FrameBuffer* output_buffer) {
- RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
- RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
- @@ -1425,7 +1456,8 @@ absl::Status LibyuvFrameBufferUtils::Rotate(const FrameBuffer& buffer,
- }
-
- absl::Status LibyuvFrameBufferUtils::FlipHorizontally(
- - const FrameBuffer& buffer, FrameBuffer* output_buffer) {
- + const FrameBuffer& buffer,
- + FrameBuffer* output_buffer) {
- RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
- RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
- RETURN_IF_ERROR(ValidateFlipBufferInputs(buffer, *output_buffer));
- @@ -1453,7 +1485,8 @@ absl::Status LibyuvFrameBufferUtils::FlipHorizontally(
- }
-
- absl::Status LibyuvFrameBufferUtils::FlipVertically(
- - const FrameBuffer& buffer, FrameBuffer* output_buffer) {
- + const FrameBuffer& buffer,
- + FrameBuffer* output_buffer) {
- RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
- RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
- RETURN_IF_ERROR(ValidateFlipBufferInputs(buffer, *output_buffer));
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h
- index 5da898bc058a4..6f83559139130 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h
- @@ -41,7 +41,11 @@ class LibyuvFrameBufferUtils : public FrameBufferUtilsInterface {
- //
- // Crop region dimensions must be equal or smaller than input `buffer`
- // dimensions.
- - absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
- + absl::Status Crop(const FrameBuffer& buffer,
- + int x0,
- + int y0,
- + int x1,
- + int y1,
- FrameBuffer* output_buffer) override;
-
- // Resizes `buffer` to the size of the given `output_buffer`.
- @@ -51,7 +55,8 @@ class LibyuvFrameBufferUtils : public FrameBufferUtilsInterface {
- // Rotates `buffer` counter-clockwise by the given `angle_deg` (in degrees).
- //
- // The given angle must be a multiple of 90 degrees.
- - absl::Status Rotate(const FrameBuffer& buffer, int angle_deg,
- + absl::Status Rotate(const FrameBuffer& buffer,
- + int angle_deg,
- FrameBuffer* output_buffer) override;
-
- // Flips `buffer` horizontally.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc
- index bc57c0b904534..d58969d96827e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc
- @@ -20,11 +20,11 @@ limitations under the License.
- #include <utility>
- #include <vector>
-
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/str_format.h" // from @com_google_absl
- -#include "absl/strings/str_split.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/str_format.h" // from @com_google_absl
- +#include "absl/strings/str_split.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- -#include "absl/types/optional.h" // from @com_google_absl
- +#include "absl/types/optional.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/common.h"
- #include "tensorflow_lite_support/cc/port/status_macros.h"
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h
- index 95cbecf54bd1d..e2b403d9b35b9 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h
- @@ -23,9 +23,9 @@ limitations under the License.
- #include <vector>
-
- #include "absl/container/flat_hash_map.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/string_view.h" // from @com_google_absl
- -#include "absl/types/optional.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/string_view.h" // from @com_google_absl
- +#include "absl/types/optional.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
- #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
- @@ -37,7 +37,10 @@ namespace vision {
- // Sigmoid structure.
- struct Sigmoid {
- Sigmoid() : scale(1.0) {}
- - Sigmoid(std::string label, float slope, float offset, float scale = 1.0,
- + Sigmoid(std::string label,
- + float slope,
- + float offset,
- + float scale = 1.0,
- absl::optional<float> min_uncalibrated_score = absl::nullopt)
- : label(label),
- slope(slope),
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc
- index 311994c1abbf9..bc2f9dfd53a96 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc
- @@ -16,7 +16,7 @@ limitations under the License.
- #include "tensorflow_lite_support/cc/common.h"
-
- #include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/cord.h" // from @com_google_absl
- +#include "absl/strings/cord.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/gmock.h"
- #include "tensorflow_lite_support/cc/port/gtest.h"
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc
- index 9a00e2f9e89a1..ef0e783e97c3e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc
- @@ -46,8 +46,8 @@ constexpr char kTestDataDirectory[] =
- constexpr char kDilatedConvolutionModelWithMetaData[] = "dilated_conv.tflite";
-
- StatusOr<ImageData> LoadImage(std::string image_name) {
- - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, image_name));
- + return DecodeImageFromFile(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
- }
-
- class DynamicInputTest : public tflite_shims::testing::Test {
- @@ -60,7 +60,7 @@ class DynamicInputTest : public tflite_shims::testing::Test {
- SUPPORT_ASSERT_OK(engine_->InitInterpreter());
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(auto preprocessor,
- - ImagePreprocessor::Create(engine_.get(), {0}));
- + ImagePreprocessor::Create(engine_.get(), {0}));
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
- std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
- @@ -94,9 +94,10 @@ TEST_F(DynamicInputTest, GoldenImageComparison) {
- PreprocessImage();
-
- // Get the processed input image.
- - SUPPORT_ASSERT_OK_AND_ASSIGN(float* processed_input_data,
- - tflite::task::core::AssertAndReturnTypedTensor<float>(
- - engine_->GetInputs()[0]));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + float* processed_input_data,
- + tflite::task::core::AssertAndReturnTypedTensor<float>(
- + engine_->GetInputs()[0]));
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
- const uint8* image_data = image.pixel_data;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
- index 629f069e7b8d1..c4a8cea0d53b9 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
- @@ -49,8 +49,7 @@ constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";
- constexpr int kMaxSeqLen = 128;
-
- std::string GetFullPath(absl::string_view file_name) {
- - return JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - file_name);
- + return JoinPath("./" /*test src dir*/, kTestDataDirectory, file_name);
- }
-
- class BertNLClassifierTest : public tflite_shims::testing::Test {};
- @@ -77,14 +76,15 @@ TEST_F(BertNLClassifierTest, CreateFromOptionsFailsWithMissingBaseOptions) {
- }
-
- TEST_F(BertNLClassifierTest, TestNLClassifierCreationFilePath) {
- - SUPPORT_ASSERT_OK(BertNLClassifier::CreateFromFile(GetFullPath(kTestModelPath)));
- + SUPPORT_ASSERT_OK(
- + BertNLClassifier::CreateFromFile(GetFullPath(kTestModelPath)));
- }
-
- TEST_F(BertNLClassifierTest, TestNLClassifierCreationBinary) {
- std::string model_buffer =
- LoadBinaryContent(GetFullPath(kTestModelPath).c_str());
- SUPPORT_ASSERT_OK(BertNLClassifier::CreateFromBuffer(model_buffer.data(),
- - model_buffer.size()));
- + model_buffer.size()));
- }
-
- TEST_F(BertNLClassifierTest, TestNLClassifierCreationFailure) {
- @@ -136,7 +136,7 @@ TEST_F(BertNLClassifierTest, ClassifySucceedsWithBaseOptions) {
- contents);
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(classifier,
- - BertNLClassifier::CreateFromOptions(options));
- + BertNLClassifier::CreateFromOptions(options));
- }
-
- verify_classifier(std::move(classifier), /*verify_positive=*/false);
- @@ -146,8 +146,8 @@ TEST_F(BertNLClassifierTest, TestNLClassifier_ClassifyNegative) {
- std::string model_buffer =
- LoadBinaryContent(GetFullPath(kTestModelPath).c_str());
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier,
- - BertNLClassifier::CreateFromBuffer(model_buffer.data(),
- - model_buffer.size()));
- + BertNLClassifier::CreateFromBuffer(
- + model_buffer.data(), model_buffer.size()));
-
- verify_classifier(std::move(classifier), false);
- }
- @@ -156,24 +156,26 @@ TEST_F(BertNLClassifierTest, TestNLClassifier_ClassifyPositive) {
- std::string model_buffer =
- LoadBinaryContent(GetFullPath(kTestModelPath).c_str());
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier,
- - BertNLClassifier::CreateFromBuffer(model_buffer.data(),
- - model_buffer.size()));
- + BertNLClassifier::CreateFromBuffer(
- + model_buffer.data(), model_buffer.size()));
-
- verify_classifier(std::move(classifier), true);
- }
-
- TEST_F(BertNLClassifierTest, TestNLClassifierFd_ClassifyPositive) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier,
- - BertNLClassifier::CreateFromFd(open(
- - GetFullPath(kTestModelPath).c_str(), O_RDONLY)));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::unique_ptr<BertNLClassifier> classifier,
- + BertNLClassifier::CreateFromFd(
- + open(GetFullPath(kTestModelPath).c_str(), O_RDONLY)));
-
- verify_classifier(std::move(classifier), false);
- }
-
- TEST_F(BertNLClassifierTest, TestNLClassifierFd_ClassifyNegative) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier,
- - BertNLClassifier::CreateFromFd(open(
- - GetFullPath(kTestModelPath).c_str(), O_RDONLY)));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::unique_ptr<BertNLClassifier> classifier,
- + BertNLClassifier::CreateFromFd(
- + open(GetFullPath(kTestModelPath).c_str(), O_RDONLY)));
-
- verify_classifier(std::move(classifier), true);
- }
- @@ -191,8 +193,8 @@ TEST_F(BertNLClassifierTest, TestNLClassifier_ClassifyLongPositive_notOOB) {
- }
- ss_for_positive_review << " movie review";
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier,
- - BertNLClassifier::CreateFromBuffer(model_buffer.data(),
- - model_buffer.size()));
- + BertNLClassifier::CreateFromBuffer(
- + model_buffer.data(), model_buffer.size()));
-
- std::vector<core::Category> results =
- classifier->Classify(ss_for_positive_review.str());
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc
- index 252441df1cb59..a70dab7782044 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc
- @@ -69,8 +69,7 @@ constexpr int kPredictAnsNum = 5;
- class BertQuestionAnswererTest : public tflite_shims::testing::Test {};
-
- std::string GetFullPath(absl::string_view file_name) {
- - return JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - file_name);
- + return JoinPath("./" /*test src dir*/, kTestDataDirectory, file_name);
- }
-
- TEST_F(BertQuestionAnswererTest,
- @@ -108,8 +107,8 @@ TEST_F(BertQuestionAnswererTest, AnswerSucceedsWithModelWithMetadata) {
- options.mutable_base_options()->mutable_model_file()->set_file_content(
- contents);
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(question_answerer,
- - BertQuestionAnswerer::CreateFromOptions(options));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + question_answerer, BertQuestionAnswerer::CreateFromOptions(options));
- }
-
- std::vector<QaAnswer> answer = question_answerer->Answer(kContext, kQuestion);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/clu_lib/bert_utils_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/clu_lib/bert_utils_test.cc
- index 3d98fe16b07e9..6fd9508fd1ba0 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/clu_lib/bert_utils_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/clu_lib/bert_utils_test.cc
- @@ -128,10 +128,10 @@ TEST_F(BertUtilsTestClass, ZeroHistoryNotTrucated) {
- std::vector<int> subword_indicators;
- std::vector<int> segment_id_list;
- std::vector<int> turn_id_list;
- - SUPPORT_ASSERT_OK(BertPreprocessing(tokenizer_.get(), conversations_in_reverse_order,
- - max_seq_length, max_history_turns, &token_ids,
- - &token_alignments, &subword_indicators,
- - &segment_id_list, &turn_id_list));
- + SUPPORT_ASSERT_OK(BertPreprocessing(
- + tokenizer_.get(), conversations_in_reverse_order, max_seq_length,
- + max_history_turns, &token_ids, &token_alignments, &subword_indicators,
- + &segment_id_list, &turn_id_list));
- EXPECT_THAT(token_ids, expected_token_ids);
- EXPECT_THAT(token_alignments, expected_token_alignments);
- EXPECT_THAT(subword_indicators, expected_first_subword_indicators);
- @@ -193,10 +193,10 @@ TEST_F(BertUtilsTestClass, ZeroHistoryTrucated) {
- std::vector<int> subword_indicators;
- std::vector<int> segment_id_list;
- std::vector<int> turn_id_list;
- - SUPPORT_ASSERT_OK(BertPreprocessing(tokenizer_.get(), conversations_in_reverse_order,
- - max_seq_length, max_history_turns, &token_ids,
- - &token_alignments, &subword_indicators,
- - &segment_id_list, &turn_id_list));
- + SUPPORT_ASSERT_OK(BertPreprocessing(
- + tokenizer_.get(), conversations_in_reverse_order, max_seq_length,
- + max_history_turns, &token_ids, &token_alignments, &subword_indicators,
- + &segment_id_list, &turn_id_list));
- EXPECT_THAT(token_ids, expected_token_ids);
- EXPECT_THAT(token_alignments, expected_token_alignments);
- EXPECT_THAT(subword_indicators, expected_first_subword_indicators);
- @@ -342,10 +342,10 @@ TEST_F(BertUtilsTestClass, WithHistoryNotTrucated) {
- std::vector<int> subword_indicators;
- std::vector<int> segment_id_list;
- std::vector<int> turn_id_list;
- - SUPPORT_ASSERT_OK(BertPreprocessing(tokenizer_.get(), conversations_in_reverse_order,
- - max_seq_length, max_history_turns, &token_ids,
- - &token_alignments, &subword_indicators,
- - &segment_id_list, &turn_id_list));
- + SUPPORT_ASSERT_OK(BertPreprocessing(
- + tokenizer_.get(), conversations_in_reverse_order, max_seq_length,
- + max_history_turns, &token_ids, &token_alignments, &subword_indicators,
- + &segment_id_list, &turn_id_list));
- EXPECT_THAT(token_ids, expected_token_ids);
- EXPECT_THAT(token_alignments, expected_token_alignments);
- EXPECT_THAT(subword_indicators, expected_first_subword_indicators);
- @@ -458,10 +458,10 @@ TEST_F(BertUtilsTestClass, WithHistoryTrucated) {
- std::vector<int> subword_indicators;
- std::vector<int> segment_id_list;
- std::vector<int> turn_id_list;
- - SUPPORT_ASSERT_OK(BertPreprocessing(tokenizer_.get(), conversations_in_reverse_order,
- - max_seq_length, max_history_turns, &token_ids,
- - &token_alignments, &subword_indicators,
- - &segment_id_list, &turn_id_list));
- + SUPPORT_ASSERT_OK(BertPreprocessing(
- + tokenizer_.get(), conversations_in_reverse_order, max_seq_length,
- + max_history_turns, &token_ids, &token_alignments, &subword_indicators,
- + &segment_id_list, &turn_id_list));
- EXPECT_THAT(token_ids, expected_token_ids);
- EXPECT_THAT(token_alignments, expected_token_alignments);
- EXPECT_THAT(subword_indicators, expected_first_subword_indicators);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/clu_lib/intent_repr_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/clu_lib/intent_repr_test.cc
- index 8341751bbbac2..0501ec4a669b5 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/clu_lib/intent_repr_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/clu_lib/intent_repr_test.cc
- @@ -29,7 +29,7 @@ TEST(IntentClassification, IntentRepr) {
-
- TEST(IntentClassification, IntentRepr2) {
- SUPPORT_ASSERT_OK_AND_ASSIGN(const auto intent_repr,
- - IntentRepr::CreateFromFullName("REQUEST"));
- + IntentRepr::CreateFromFullName("REQUEST"));
- EXPECT_EQ(intent_repr.Name(), "REQUEST");
- EXPECT_EQ(intent_repr.Domain(), "");
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc
- index 67b03c3a45323..81198cfca30fc 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc
- @@ -121,8 +121,7 @@ struct ProtoOptionsTestParam {
- };
-
- std::string GetFullPath(absl::string_view file_name) {
- - return JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - file_name);
- + return JoinPath("./" /*test src dir*/, kTestDataDirectory, file_name);
- }
-
- class ProtoOptionsTest : public TestWithParam<ProtoOptionsTestParam> {
- @@ -163,7 +162,8 @@ TEST_F(ProtoOptionsTest, ClassifySucceedsWithBaseOptions) {
- options.mutable_base_options()->mutable_model_file()->set_file_content(
- contents);
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(classifier, NLClassifier::CreateFromOptions(options));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(classifier,
- + NLClassifier::CreateFromOptions(options));
- }
-
- std::vector<core::Category> positive_results =
- @@ -180,8 +180,8 @@ TEST_F(ProtoOptionsTest, ClassifySucceedsWithBaseOptions) {
-
- TEST_F(ProtoOptionsTest, CreationFromIncorrectInputTensor) {
- NLClassifierProtoOptions options;
- - options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kTestModelPath));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath));
- options.set_input_tensor_name("invalid_tensor_name");
- options.set_input_tensor_index(-1);
-
- @@ -200,8 +200,8 @@ TEST_F(ProtoOptionsTest, CreationFromIncorrectInputTensor) {
-
- TEST_F(ProtoOptionsTest, CreationFromIncorrectOutputScoreTensor) {
- NLClassifierProtoOptions options;
- - options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kTestModelPath));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath));
- options.set_output_score_tensor_name("invalid_tensor_name");
- options.set_output_score_tensor_index(-1);
-
- @@ -224,7 +224,7 @@ TEST_F(ProtoOptionsTest, TestInferenceWithRegexTokenizer) {
- options.mutable_base_options()->mutable_model_file()->set_file_name(
- GetFullPath(kTestModelWithRegexTokenizer));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier,
- - NLClassifier::CreateFromOptions(options));
- + NLClassifier::CreateFromOptions(options));
-
- std::vector<core::Category> positive_results =
- classifier->Classify(kPositiveInput);
- @@ -277,7 +277,7 @@ TEST_F(ProtoOptionsTest, TestInferenceWithAssociatedLabelBuiltinOps) {
- options.mutable_base_options()->mutable_model_file()->set_file_name(
- GetFullPath(kTestModelWithLabelBuiltInOpsPath));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier,
- - NLClassifier::CreateFromOptions(options));
- + NLClassifier::CreateFromOptions(options));
- std::vector<core::Category> results = classifier->Classify(kInputStr);
- std::vector<core::Category> expected_class = {
- {"Negative", 0.49332118034362793},
- @@ -296,8 +296,10 @@ struct ProtoOptionsTestParamToString {
- };
-
- NLClassifierProtoOptions CreateProtoOptionsFromTensorName(
- - const char* input_tensor_name, const char* output_score_tensor_name,
- - const char* output_label_tensor_name, const char* model_path) {
- + const char* input_tensor_name,
- + const char* output_score_tensor_name,
- + const char* output_label_tensor_name,
- + const char* model_path) {
- NLClassifierProtoOptions options;
- options.set_input_tensor_name(input_tensor_name);
- options.set_output_score_tensor_name(output_score_tensor_name);
- @@ -310,8 +312,10 @@ NLClassifierProtoOptions CreateProtoOptionsFromTensorName(
- }
-
- NLClassifierProtoOptions CreateProtoOptionsFromTensorIndex(
- - const int input_tensor_index, const int output_score_tensor_index,
- - const int output_label_tensor_index, const char* model_path) {
- + const int input_tensor_index,
- + const int output_score_tensor_index,
- + const int output_label_tensor_index,
- + const char* model_path) {
- NLClassifierProtoOptions options;
- options.set_input_tensor_index(input_tensor_index);
- options.set_output_score_tensor_index(output_score_tensor_index);
- @@ -439,14 +443,16 @@ TEST_P(ProtoOptionsTest, TestClassify) {
- EXPECT_THAT(results, UnorderedElementsAreArray(expected_class));
- }
-
- -INSTANTIATE_TEST_SUITE_P(TestClassify, ProtoOptionsTest,
- +INSTANTIATE_TEST_SUITE_P(TestClassify,
- + ProtoOptionsTest,
- ValuesIn(ClassifyParams()),
- ProtoOptionsTestParamToString());
-
- // Tests for struct sNLClassifierOptions.
- class StructOptionsTest : public tflite_shims::testing::Test {};
-
- -void AssertStatus(absl::Status status, absl::StatusCode status_code,
- +void AssertStatus(absl::Status status,
- + absl::StatusCode status_code,
- TfLiteSupportStatus tfls_code) {
- ASSERT_EQ(status.code(), status_code);
- EXPECT_THAT(status.GetPayload(kTfLiteSupportPayload),
- @@ -454,30 +460,29 @@ void AssertStatus(absl::Status status, absl::StatusCode status_code,
- }
-
- TEST_F(StructOptionsTest, TestApiCreationFromBuffer) {
- - std::string model_buffer =
- - LoadBinaryContent(JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, kTestModelPath)
- - .c_str());
- + std::string model_buffer = LoadBinaryContent(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath)
- + .c_str());
- SUPPORT_ASSERT_OK(NLClassifier::CreateFromBufferAndOptions(
- model_buffer.data(), model_buffer.size(), {}, CreateCustomResolver()));
- }
-
- TEST_F(StructOptionsTest, TestApiCreationFromFile) {
- - SUPPORT_ASSERT_OK(NLClassifier::CreateFromFileAndOptions(GetFullPath(kTestModelPath),
- - {}, CreateCustomResolver()));
- + SUPPORT_ASSERT_OK(NLClassifier::CreateFromFileAndOptions(
- + GetFullPath(kTestModelPath), {}, CreateCustomResolver()));
- }
-
- TEST_F(StructOptionsTest, TestApiCreationFromIncorrectInputTensor) {
- NLClassifierOptions options;
- options.input_tensor_index = -1;
- options.input_tensor_name = "I do not exist";
- - AssertStatus(NLClassifier::CreateFromFileAndOptions(
- - JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, kTestModelPath),
- - options, CreateCustomResolver())
- - .status(),
- - absl::StatusCode::kInvalidArgument,
- - TfLiteSupportStatus::kInputTensorNotFoundError);
- + AssertStatus(
- + NLClassifier::CreateFromFileAndOptions(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath),
- + options, CreateCustomResolver())
- + .status(),
- + absl::StatusCode::kInvalidArgument,
- + TfLiteSupportStatus::kInputTensorNotFoundError);
- }
-
- TEST_F(StructOptionsTest, TestApiCreationFromIncorrectOutputScoreTensor) {
- @@ -497,9 +502,10 @@ TEST_F(StructOptionsTest, TestInferenceWithRegexTokenizer) {
- options.output_score_tensor_name = "probability";
-
- // The model with regex tokenizer doesn't need any custom ops.
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier,
- - NLClassifier::CreateFromFileAndOptions(
- - GetFullPath(kTestModelWithRegexTokenizer), options));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::unique_ptr<NLClassifier> classifier,
- + NLClassifier::CreateFromFileAndOptions(
- + GetFullPath(kTestModelWithRegexTokenizer), options));
-
- std::vector<core::Category> positive_results =
- classifier->Classify(kPositiveInput);
- @@ -519,9 +525,9 @@ TEST_F(StructOptionsTest, TestInferenceWithBoolOutput) {
- options.output_score_tensor_index = 0;
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier,
- - NLClassifier::CreateFromFileAndOptions(
- - GetFullPath(kTestModelBoolOutputPath), options,
- - CreateCustomResolver()));
- + NLClassifier::CreateFromFileAndOptions(
- + GetFullPath(kTestModelBoolOutputPath),
- + options, CreateCustomResolver()));
- std::vector<core::Category> results = classifier->Classify(kInputStr);
- std::vector<core::Category> expected_class = {
- {"0", 1},
- @@ -535,10 +541,11 @@ TEST_F(StructOptionsTest, TestInferenceWithBoolOutput) {
- TEST_F(StructOptionsTest, TestInferenceWithAssociatedLabelCustomOps) {
- NLClassifierOptions options;
- options.output_score_tensor_name = kMetadataOutputScoreTensorName;
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier,
- - NLClassifier::CreateFromFileAndOptions(
- - GetFullPath(kTestModelWithLabelCustomOpsPath),
- - options, CreateCustomResolver()));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::unique_ptr<NLClassifier> classifier,
- + NLClassifier::CreateFromFileAndOptions(
- + GetFullPath(kTestModelWithLabelCustomOpsPath), options,
- + CreateCustomResolver()));
- std::vector<core::Category> results = classifier->Classify(kInputStr);
- std::vector<core::Category> expected_class = {
- {"label0", 255},
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_embedder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_embedder_test.cc
- index 5a86a288b4624..b097813ecedf7 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_embedder_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_embedder_test.cc
- @@ -17,7 +17,7 @@ limitations under the License.
-
- #include <iostream>
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
- #include "tensorflow_lite_support/cc/port/gmock.h"
- @@ -56,8 +56,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {};
-
- TextEmbedderOptions GetBasicOptions(absl::string_view model_name) {
- TextEmbedderOptions options;
- - options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, model_name));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, model_name));
- return options;
- }
-
- @@ -130,7 +130,7 @@ TEST(EmbedTest, SucceedsWithMobileBertModel) {
- TextEmbedderOptions options = GetBasicOptions(kMobileBert);
- // No Embedding options means all head get a default option.
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
- - TextEmbedder::CreateFromOptions(options));
- + TextEmbedder::CreateFromOptions(options));
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(
- auto result0,
- @@ -141,8 +141,8 @@ TEST(EmbedTest, SucceedsWithMobileBertModel) {
- EXPECT_NEAR(result0.embeddings(0).feature_vector().value_float(0), 19.9016f,
- kValueDiffTolerance);
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(auto result1,
- - text_embedder->Embed("what a great and fantastic trip"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + auto result1, text_embedder->Embed("what a great and fantastic trip"));
- EXPECT_EQ(result1.embeddings_size(), 1);
- EXPECT_EQ(result1.embeddings(0).feature_vector().value_float_size(), 512);
-
- @@ -162,7 +162,7 @@ TEST(EmbedTest, SucceedsWithRegexModel) {
- TextEmbedderOptions options = GetBasicOptions(kRegexOneEmbeddingModel);
- // No Embedding options means all head get a default option.
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
- - TextEmbedder::CreateFromOptions(options));
- + TextEmbedder::CreateFromOptions(options));
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(
- auto result0,
- @@ -173,8 +173,8 @@ TEST(EmbedTest, SucceedsWithRegexModel) {
- EXPECT_NEAR(result0.embeddings(0).feature_vector().value_float(0), 0.0309356f,
- kValueDiffTolerance);
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(auto result1,
- - text_embedder->Embed("what a great and fantastic trip"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + auto result1, text_embedder->Embed("what a great and fantastic trip"));
- EXPECT_EQ(result1.embeddings_size(), 1);
- EXPECT_EQ(result1.embeddings(0).feature_vector().value_float_size(), 16);
-
- @@ -206,8 +206,8 @@ TEST(EmbedTest, SucceedsWithUniversalSentenceEncoder) {
- EXPECT_NEAR(result0.embeddings(0).feature_vector().value_float(0), 1.422951f,
- kValueDiffTolerance);
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(auto result1,
- - text_embedder->Embed("what a great and fantastic trip"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + auto result1, text_embedder->Embed("what a great and fantastic trip"));
- EXPECT_EQ(result1.embeddings_size(), 1);
- EXPECT_EQ(result1.embeddings(0).feature_vector().value_float_size(), 100);
-
- @@ -227,7 +227,7 @@ TEST(GetEmbeddingDimension, Succeeds) {
- // Create embedder.
- TextEmbedderOptions options = GetBasicOptions(kMobileBert);
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
- - TextEmbedder::CreateFromOptions(options));
- + TextEmbedder::CreateFromOptions(options));
-
- EXPECT_EQ(text_embedder->GetEmbeddingDimension(0), 512);
- EXPECT_EQ(text_embedder->GetEmbeddingDimension(1), -1);
- @@ -238,7 +238,7 @@ TEST(GetNumberOfOutputLayers, Succeeds) {
- TextEmbedderOptions options = GetBasicOptions(kMobileBert);
- // No Embedding options means all head get a default option.
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
- - TextEmbedder::CreateFromOptions(options));
- + TextEmbedder::CreateFromOptions(options));
- EXPECT_EQ(text_embedder->GetNumberOfOutputLayers(), kNumberOfOutputLayers);
- }
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_searcher_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_searcher_test.cc
- index fec09a1ad77cc..f38615c5b3092 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_searcher_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_searcher_test.cc
- @@ -18,9 +18,9 @@ limitations under the License.
- #include <memory>
- #include <string>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/cord.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/cord.h" // from @com_google_absl
- #include "absl/strings/str_cat.h" // from @com_google_absl
- #include "tensorflow/lite/core/api/op_resolver.h"
- #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
- @@ -219,7 +219,8 @@ TEST_P(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
- }
-
- INSTANTIATE_TEST_SUITE_P(
- - CreateFromOptionsTest, CreateFromOptionsTest,
- + CreateFromOptionsTest,
- + CreateFromOptionsTest,
- Values(CreateFromOptionsParams{.name = "Bert",
- .embedder_model_name = kMobileBertEmbedder,
- .searcher_model_name = kMobileBertSearcher,
- @@ -267,7 +268,7 @@ TEST_P(SearchTest, SucceedsWithStandaloneIndex) {
-
- // Perform search.
- SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result,
- - searcher->Search("The weather was excellent."));
- + searcher->Search("The weather was excellent."));
-
- // Check results.
- ExpectApproximatelyEqual(
- @@ -288,7 +289,7 @@ TEST_P(SearchTest, SucceedsWithMetadataIndex) {
-
- // Perform search.
- SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result,
- - searcher->Search("The weather was excellent."));
- + searcher->Search("The weather was excellent."));
-
- // Check results.
- ExpectApproximatelyEqual(
- @@ -313,7 +314,7 @@ TEST_P(SearchTest, SucceedsWithMaxResults) {
-
- // Perform search.
- SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result,
- - searcher->Search("The weather was excellent."));
- + searcher->Search("The weather was excellent."));
-
- // Check results.
- SearchResult all_results =
- @@ -327,7 +328,8 @@ TEST_P(SearchTest, SucceedsWithMaxResults) {
- }
-
- INSTANTIATE_TEST_SUITE_P(
- - SearchTest, SearchTest,
- + SearchTest,
- + SearchTest,
- Values(
- SearchParams{
- .name = "Bert",
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/universal_sentence_encoder_qa_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/universal_sentence_encoder_qa_test.cc
- index 2529060cab275..5f0535b5c1438 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/universal_sentence_encoder_qa_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/universal_sentence_encoder_qa_test.cc
- @@ -77,8 +77,7 @@ class UniversalSentenceEncoderQATest : public tflite_shims::testing::Test {
- public:
- UniversalSentenceEncoderQATest() {
- // Load model file, and create qa client.
- - const auto filename =
- - JoinPath("./" /*test src dir*/, kTestUseQaModelDir);
- + const auto filename = JoinPath("./" /*test src dir*/, kTestUseQaModelDir);
- RetrievalOptions options;
- options.mutable_base_options()->mutable_model_file()->set_file_name(
- filename);
- @@ -96,7 +95,7 @@ class UniversalSentenceEncoderQATest : public tflite_shims::testing::Test {
- TEST_F(UniversalSentenceEncoderQATest, TestEncodeQuery) {
- ASSERT_TRUE(qa_client_ != nullptr);
- SUPPORT_ASSERT_OK_AND_ASSIGN(const auto encoded_question,
- - qa_client_->EncodeQuery(kQuery));
- + qa_client_->EncodeQuery(kQuery));
- EXPECT_EQ(UniversalSentenceEncoderQA::kFinalEmbeddingSize,
- encoded_question.value_float_size());
-
- @@ -107,7 +106,7 @@ TEST_F(UniversalSentenceEncoderQATest, TestEncodeQuery) {
- TEST_F(UniversalSentenceEncoderQATest, TestEncodeResponse) {
- ASSERT_TRUE(qa_client_ != nullptr);
- SUPPORT_ASSERT_OK_AND_ASSIGN(const auto encoded_response,
- - qa_client_->EncodeResponse(kResponse, kContext));
- + qa_client_->EncodeResponse(kResponse, kContext));
- EXPECT_EQ(UniversalSentenceEncoderQA::kFinalEmbeddingSize,
- encoded_response.value_float_size());
-
- @@ -208,13 +207,14 @@ TEST_F(UniversalSentenceEncoderQATest, TestRetrieveWithEncoding) {
- ASSERT_TRUE(qa_client_ != nullptr);
- RetrievalInput input;
- input.set_query_text(kQueryComp);
- - SUPPORT_ASSERT_OK_AND_ASSIGN(const auto& query, qa_client_->EncodeQuery(kQueryComp));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(const auto& query,
- + qa_client_->EncodeQuery(kQueryComp));
- SUPPORT_ASSERT_OK_AND_ASSIGN(const auto& resp0,
- - qa_client_->EncodeResponse(kResponseComp0, ""));
- + qa_client_->EncodeResponse(kResponseComp0, ""));
- SUPPORT_ASSERT_OK_AND_ASSIGN(const auto& resp1,
- - qa_client_->EncodeResponse(kResponseComp1, ""));
- + qa_client_->EncodeResponse(kResponseComp1, ""));
- SUPPORT_ASSERT_OK_AND_ASSIGN(const auto& resp2,
- - qa_client_->EncodeResponse(kResponseComp2, ""));
- + qa_client_->EncodeResponse(kResponseComp2, ""));
- *input.mutable_responses()->Add()->mutable_text_encoding() = resp0;
- *input.mutable_responses()->Add()->mutable_text_encoding() = resp1;
- *input.mutable_responses()->Add()->mutable_text_encoding() = resp2;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc
- index 6a0ce66dde9b5..2daf293b48f05 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc
- @@ -17,9 +17,9 @@ limitations under the License.
-
- #include <memory>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- #include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/cord.h" // from @com_google_absl
- +#include "absl/strings/cord.h" // from @com_google_absl
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
- #include "tensorflow/lite/kernels/builtin_op_kernels.h"
- @@ -70,8 +70,8 @@ constexpr char kMobileNetQuantizedWithMetadata[] =
- constexpr char kAutoMLModelWithMetadata[] = "automl_labeler_model.tflite";
-
- StatusOr<ImageData> LoadImage(std::string image_name) {
- - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, image_name));
- + return DecodeImageFromFile(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
- }
-
- // If the proto definition changes, please also change this function.
- @@ -159,9 +159,8 @@ TEST_F(CreateFromOptionsTest, FailsWithTwoModelSources) {
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetQuantizedWithMetadata));
- - options.mutable_base_options()->mutable_model_file()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetFloatWithMetadata));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
-
- StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or =
- ImageClassifier::CreateFromOptions(options);
- @@ -234,9 +233,8 @@ TEST_F(CreateFromOptionsTest, FailsWithCombinedWhitelistAndBlacklist) {
- TEST_F(CreateFromOptionsTest, SucceedsWithNumberOfThreads) {
- ImageClassifierOptions options;
- options.set_num_threads(4);
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetFloatWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
-
- SUPPORT_ASSERT_OK(ImageClassifier::CreateFromOptions(options));
- }
- @@ -248,9 +246,8 @@ INSTANTIATE_TEST_SUITE_P(Default, NumThreadsTest, testing::Values(0, -2));
- TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) {
- ImageClassifierOptions options;
- options.set_num_threads(GetParam());
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetFloatWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
-
- StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or =
- ImageClassifier::CreateFromOptions(options);
- @@ -273,12 +270,12 @@ TEST(ClassifyTest, SucceedsWithFloatModel) {
-
- ImageClassifierOptions options;
- options.set_max_results(3);
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetFloatWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- - ImageClassifier::CreateFromOptions(options));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::unique_ptr<ImageClassifier> image_classifier,
- + ImageClassifier::CreateFromOptions(options));
-
- StatusOr<ClassificationResult> result_or =
- image_classifier->Classify(*frame_buffer);
- @@ -307,19 +304,20 @@ TEST(ClassifyTest, SucceedsWithFloatModel) {
- }
-
- TEST(ClassifyTest, SucceedsWithRegionOfInterest) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("multi_objects.jpg"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
- + LoadImage("multi_objects.jpg"));
- std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
- rgb_image.pixel_data,
- FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
-
- ImageClassifierOptions options;
- options.set_max_results(1);
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetFloatWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- - ImageClassifier::CreateFromOptions(options));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::unique_ptr<ImageClassifier> image_classifier,
- + ImageClassifier::CreateFromOptions(options));
-
- // Crop around the soccer ball.
- BoundingBox roi;
- @@ -358,8 +356,9 @@ TEST(ClassifyTest, SucceedsWithQuantizedModel) {
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetQuantizedWithMetadata));
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- - ImageClassifier::CreateFromOptions(options));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::unique_ptr<ImageClassifier> image_classifier,
- + ImageClassifier::CreateFromOptions(options));
-
- StatusOr<ClassificationResult> result_or =
- image_classifier->Classify(*frame_buffer);
- @@ -391,12 +390,12 @@ TEST(ClassifyTest, SucceedsWithBaseOptions) {
-
- ImageClassifierOptions options;
- options.set_max_results(3);
- - options.mutable_base_options()->mutable_model_file()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetFloatWithMetadata));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- - ImageClassifier::CreateFromOptions(options));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::unique_ptr<ImageClassifier> image_classifier,
- + ImageClassifier::CreateFromOptions(options));
-
- StatusOr<ClassificationResult> result_or =
- image_classifier->Classify(*frame_buffer);
- @@ -452,8 +451,8 @@ TEST(ClassifyTest, SucceedsWithMiniBenchmark) {
- rgb_image.pixel_data,
- FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
-
- - auto file_name = JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, kMobileNetFloatWithMetadata);
- + auto file_name = JoinPath("./" /*test src dir*/, kTestDataDirectory,
- + kMobileNetFloatWithMetadata);
-
- ImageClassifierOptions options;
- options.set_max_results(3);
- @@ -462,8 +461,9 @@ TEST(ClassifyTest, SucceedsWithMiniBenchmark) {
-
- ConfigureXnnPackMiniBenchmark(/*num_threads=*/2, options);
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- - ImageClassifier::CreateFromOptions(options));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::unique_ptr<ImageClassifier> image_classifier,
- + ImageClassifier::CreateFromOptions(options));
-
- StatusOr<ClassificationResult> result_or =
- image_classifier->Classify(*frame_buffer);
- @@ -493,11 +493,11 @@ TEST(ClassifyTest, SucceedsWithMiniBenchmark) {
-
- TEST(ClassifyTest, GetInputCountSucceeds) {
- ImageClassifierOptions options;
- - options.mutable_base_options()->mutable_model_file()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetFloatWithMetadata));
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- - ImageClassifier::CreateFromOptions(options));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::unique_ptr<ImageClassifier> image_classifier,
- + ImageClassifier::CreateFromOptions(options));
-
- int32_t input_count = image_classifier->GetInputCount();
- EXPECT_THAT(input_count, 1);
- @@ -505,11 +505,11 @@ TEST(ClassifyTest, GetInputCountSucceeds) {
-
- TEST(ClassifyTest, GetInputShapeSucceeds) {
- ImageClassifierOptions options;
- - options.mutable_base_options()->mutable_model_file()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetFloatWithMetadata));
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- - ImageClassifier::CreateFromOptions(options));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::unique_ptr<ImageClassifier> image_classifier,
- + ImageClassifier::CreateFromOptions(options));
-
- // Verify the shape array size.
- const TfLiteIntArray* input_shape_0 = image_classifier->GetInputShape(0);
- @@ -523,11 +523,11 @@ TEST(ClassifyTest, GetInputShapeSucceeds) {
-
- TEST(ClassifyTest, GetOutputCountSucceeds) {
- ImageClassifierOptions options;
- - options.mutable_base_options()->mutable_model_file()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetFloatWithMetadata));
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- - ImageClassifier::CreateFromOptions(options));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::unique_ptr<ImageClassifier> image_classifier,
- + ImageClassifier::CreateFromOptions(options));
-
- int32_t output_count = image_classifier->GetOutputCount();
- EXPECT_THAT(output_count, 1);
- @@ -535,11 +535,11 @@ TEST(ClassifyTest, GetOutputCountSucceeds) {
-
- TEST(ClassifyTest, GetOutputShapeSucceeds) {
- ImageClassifierOptions options;
- - options.mutable_base_options()->mutable_model_file()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetFloatWithMetadata));
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- - ImageClassifier::CreateFromOptions(options));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::unique_ptr<ImageClassifier> image_classifier,
- + ImageClassifier::CreateFromOptions(options));
-
- // Verify the shape array size.
- const TfLiteIntArray* output_shape_0 = image_classifier->GetOutputShape(0);
- @@ -604,9 +604,8 @@ class PostprocessTest : public tflite_shims::testing::Test {
-
- TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) {
- ImageClassifierOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kAutoMLModelWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata));
- options.set_max_results(3);
-
- SetUp(options);
- @@ -618,9 +617,10 @@ TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) {
- std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255,
- /*sunflowers*/ 32, /*tulips*/ 128};
- SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor));
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
- - test_image_classifier_->Postprocess(
- - {output_tensor}, *dummy_frame_buffer_, /*roi=*/{}));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + ClassificationResult result,
- + test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_,
- + /*roi=*/{}));
- ExpectApproximatelyEqual(
- result,
- ParseTextProtoOrDie<ClassificationResult>(
- @@ -635,9 +635,8 @@ TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) {
-
- TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) {
- ImageClassifierOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kAutoMLModelWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata));
- options.set_score_threshold(0.4);
-
- SetUp(options);
- @@ -649,9 +648,10 @@ TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) {
- std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255,
- /*sunflowers*/ 32, /*tulips*/ 128};
- SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor));
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
- - test_image_classifier_->Postprocess(
- - {output_tensor}, *dummy_frame_buffer_, /*roi=*/{}));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + ClassificationResult result,
- + test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_,
- + /*roi=*/{}));
-
- ExpectApproximatelyEqual(
- result,
- @@ -666,9 +666,8 @@ TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) {
-
- TEST_F(PostprocessTest, SucceedsWithWhitelistOption) {
- ImageClassifierOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kAutoMLModelWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata));
- options.add_class_name_whitelist("dandelion");
- options.add_class_name_whitelist("daisy");
-
- @@ -681,9 +680,10 @@ TEST_F(PostprocessTest, SucceedsWithWhitelistOption) {
- std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255,
- /*sunflowers*/ 32, /*tulips*/ 128};
- SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor));
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
- - test_image_classifier_->Postprocess(
- - {output_tensor}, *dummy_frame_buffer_, /*roi=*/{}));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + ClassificationResult result,
- + test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_,
- + /*roi=*/{}));
- ExpectApproximatelyEqual(
- result,
- ParseTextProtoOrDie<ClassificationResult>(
- @@ -697,9 +697,8 @@ TEST_F(PostprocessTest, SucceedsWithWhitelistOption) {
-
- TEST_F(PostprocessTest, SucceedsWithBlacklistOption) {
- ImageClassifierOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kAutoMLModelWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata));
- options.add_class_name_blacklist("dandelion");
- options.add_class_name_blacklist("daisy");
-
- @@ -712,9 +711,10 @@ TEST_F(PostprocessTest, SucceedsWithBlacklistOption) {
- std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255,
- /*sunflowers*/ 32, /*tulips*/ 128};
- SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor));
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
- - test_image_classifier_->Postprocess(
- - {output_tensor}, *dummy_frame_buffer_, /*roi=*/{}));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + ClassificationResult result,
- + test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_,
- + /*roi=*/{}));
-
- ExpectApproximatelyEqual(
- result,
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc
- index 6ce017d3f1728..41226f602a26b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc
- @@ -17,7 +17,7 @@ limitations under the License.
-
- #include <memory>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- #include "absl/status/status.h" // from @com_google_absl
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
- @@ -59,8 +59,8 @@ constexpr char kMobileNetV3[] = "mobilenet_v3_small_100_224_embedder.tflite";
- constexpr double kSimilarityTolerancy = 1e-6;
-
- StatusOr<ImageData> LoadImage(std::string image_name) {
- - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, image_name));
- + return DecodeImageFromFile(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
- }
-
- class MobileNetV3OpResolver : public ::tflite::MutableOpResolver {
- @@ -93,8 +93,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {};
-
- TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
- ImageEmbedderOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
-
- SUPPORT_ASSERT_OK(ImageEmbedder::CreateFromOptions(
- options, absl::make_unique<MobileNetV3OpResolver>()));
- @@ -113,8 +113,8 @@ class MobileNetV3OpResolverMissingOps : public ::tflite::MutableOpResolver {
-
- TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
- ImageEmbedderOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
-
- auto image_embedder_or = ImageEmbedder::CreateFromOptions(
- options, absl::make_unique<MobileNetV3OpResolverMissingOps>());
- @@ -231,8 +231,9 @@ TEST(CosineSimilarityTest, Succeeds) {
- // Prevent literal from being interpreted as null-terminated C-style string.
- *v_quantized.mutable_value_string() = std::string("\x80\x00\x00\x00", 4);
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(double float_similarity,
- - ImageEmbedder::CosineSimilarity(u_float, v_float));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + double float_similarity,
- + ImageEmbedder::CosineSimilarity(u_float, v_float));
- SUPPORT_ASSERT_OK_AND_ASSIGN(
- double quantized_similarity,
- ImageEmbedder::CosineSimilarity(u_quantized, v_quantized));
- @@ -246,10 +247,10 @@ TEST(CosineSimilarityTest, Succeeds) {
- TEST(EmbedTest, SucceedsWithoutL2Normalization) {
- // Create embedder.
- ImageEmbedderOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
- - ImageEmbedder::CreateFromOptions(options));
- + ImageEmbedder::CreateFromOptions(options));
- // Load images: one is a crop of the other.
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
- std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer(
- @@ -260,10 +261,10 @@ TEST(EmbedTest, SucceedsWithoutL2Normalization) {
-
- // Extract both embeddings.
- SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
- - embedder->Embed(*image_frame_buffer));
- + embedder->Embed(*image_frame_buffer));
- ImageDataFree(&image);
- SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
- - embedder->Embed(*crop_frame_buffer));
- + embedder->Embed(*crop_frame_buffer));
- ImageDataFree(&crop);
-
- // Check results sizes
- @@ -276,9 +277,9 @@ TEST(EmbedTest, SucceedsWithoutL2Normalization) {
- crop_result.embeddings(0).feature_vector();
- EXPECT_EQ(crop_feature_vector.value_float_size(), 1024);
- // Check cosine similarity.
- - SUPPORT_ASSERT_OK_AND_ASSIGN(double similarity,
- - ImageEmbedder::CosineSimilarity(image_feature_vector,
- - crop_feature_vector));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector,
- + crop_feature_vector));
- double expected_similarity = 0.932738;
- EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
- }
- @@ -287,11 +288,11 @@ TEST(EmbedTest, SucceedsWithoutL2Normalization) {
- TEST(EmbedTest, SucceedsWithL2Normalization) {
- // Create embedder.
- ImageEmbedderOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
- options.set_l2_normalize(true);
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
- - ImageEmbedder::CreateFromOptions(options));
- + ImageEmbedder::CreateFromOptions(options));
- // Load images: one is a crop of the other.
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
- std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer(
- @@ -302,10 +303,10 @@ TEST(EmbedTest, SucceedsWithL2Normalization) {
-
- // Extract both embeddings.
- SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
- - embedder->Embed(*image_frame_buffer));
- + embedder->Embed(*image_frame_buffer));
- ImageDataFree(&image);
- SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
- - embedder->Embed(*crop_frame_buffer));
- + embedder->Embed(*crop_frame_buffer));
- ImageDataFree(&crop);
-
- // Check results sizes
- @@ -318,9 +319,9 @@ TEST(EmbedTest, SucceedsWithL2Normalization) {
- crop_result.embeddings(0).feature_vector();
- EXPECT_EQ(crop_feature_vector.value_float_size(), 1024);
- // Check cosine similarity.
- - SUPPORT_ASSERT_OK_AND_ASSIGN(double similarity,
- - ImageEmbedder::CosineSimilarity(image_feature_vector,
- - crop_feature_vector));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector,
- + crop_feature_vector));
- double expected_similarity = 0.932738;
- EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
- }
- @@ -331,12 +332,12 @@ TEST(EmbedTest, SucceedsWithL2Normalization) {
- TEST(EmbedTest, SucceedsWithQuantization) {
- // Create embedder.
- ImageEmbedderOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
- options.set_l2_normalize(true);
- options.set_quantize(true);
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
- - ImageEmbedder::CreateFromOptions(options));
- + ImageEmbedder::CreateFromOptions(options));
- // Load images: one is a crop of the other.
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
- std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer(
- @@ -347,10 +348,10 @@ TEST(EmbedTest, SucceedsWithQuantization) {
-
- // Extract both embeddings.
- SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
- - embedder->Embed(*image_frame_buffer));
- + embedder->Embed(*image_frame_buffer));
- ImageDataFree(&image);
- SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
- - embedder->Embed(*crop_frame_buffer));
- + embedder->Embed(*crop_frame_buffer));
- ImageDataFree(&crop);
-
- // Check results sizes
- @@ -363,9 +364,9 @@ TEST(EmbedTest, SucceedsWithQuantization) {
- crop_result.embeddings(0).feature_vector();
- EXPECT_EQ(crop_feature_vector.value_string().size(), 1024);
- // Check cosine similarity.
- - SUPPORT_ASSERT_OK_AND_ASSIGN(double similarity,
- - ImageEmbedder::CosineSimilarity(image_feature_vector,
- - crop_feature_vector));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector,
- + crop_feature_vector));
- // Close to but expectedly different from the above tests due to slight loss
- // of precision during quantization:
- double expected_similarity = 0.929717;
- @@ -378,10 +379,10 @@ TEST(EmbedTest, SucceedsWithQuantization) {
- TEST(EmbedTest, SucceedsWithRegionOfInterest) {
- // Create embedder.
- ImageEmbedderOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
- - ImageEmbedder::CreateFromOptions(options));
- + ImageEmbedder::CreateFromOptions(options));
- // Load images: one is a crop of the other.
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
- std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer(
- @@ -398,10 +399,10 @@ TEST(EmbedTest, SucceedsWithRegionOfInterest) {
-
- // Extract both embeddings.
- SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
- - embedder->Embed(*image_frame_buffer, roi));
- + embedder->Embed(*image_frame_buffer, roi));
- ImageDataFree(&image);
- SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
- - embedder->Embed(*crop_frame_buffer));
- + embedder->Embed(*crop_frame_buffer));
- ImageDataFree(&crop);
-
- // Check results sizes
- @@ -414,9 +415,9 @@ TEST(EmbedTest, SucceedsWithRegionOfInterest) {
- crop_result.embeddings(0).feature_vector();
- EXPECT_EQ(crop_feature_vector.value_float_size(), 1024);
- // Check cosine similarity.
- - SUPPORT_ASSERT_OK_AND_ASSIGN(double similarity,
- - ImageEmbedder::CosineSimilarity(image_feature_vector,
- - crop_feature_vector));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector,
- + crop_feature_vector));
- double expected_similarity = 0.999914;
- EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
- }
- @@ -424,10 +425,10 @@ TEST(EmbedTest, SucceedsWithRegionOfInterest) {
- TEST(GetEmbeddingDimension, Succeeds) {
- // Create embedder.
- ImageEmbedderOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
- - ImageEmbedder::CreateFromOptions(options));
- + ImageEmbedder::CreateFromOptions(options));
-
- EXPECT_EQ(embedder->GetEmbeddingDimension(0), 1024);
- EXPECT_EQ(embedder->GetEmbeddingDimension(1), -1);
- @@ -436,10 +437,10 @@ TEST(GetEmbeddingDimension, Succeeds) {
- TEST(GetNumberOfOutputLayers, Succeeds) {
- // Create embedder.
- ImageEmbedderOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
- - ImageEmbedder::CreateFromOptions(options));
- + ImageEmbedder::CreateFromOptions(options));
-
- EXPECT_EQ(embedder->GetNumberOfOutputLayers(), 1);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_searcher_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_searcher_test.cc
- index 0b1f3b11b383c..00183eb65b5df 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_searcher_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_searcher_test.cc
- @@ -18,9 +18,9 @@ limitations under the License.
- #include <memory>
- #include <string>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/cord.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/cord.h" // from @com_google_absl
- #include "absl/strings/str_cat.h" // from @com_google_absl
- #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
- #include "tensorflow_lite_support/cc/common.h"
- @@ -66,8 +66,8 @@ constexpr char kMobileNetV3Searcher[] =
- "mobilenet_v3_small_100_224_searcher.tflite";
-
- StatusOr<ImageData> LoadImage(std::string image_name) {
- - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, image_name));
- + return DecodeImageFromFile(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
- }
-
- // Checks that the two provided `SearchResult` protos are equal, with a
- @@ -88,9 +88,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {};
-
- TEST_F(CreateFromOptionsTest, SucceedsWithStandaloneIndex) {
- ImageSearcherOptions options;
- - options.mutable_base_options()->mutable_model_file()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetV3Embedder));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder));
- options.mutable_embedding_options()->set_l2_normalize(true);
- options.mutable_search_options()->mutable_index_file()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex));
- @@ -100,9 +99,8 @@ TEST_F(CreateFromOptionsTest, SucceedsWithStandaloneIndex) {
-
- TEST_F(CreateFromOptionsTest, SucceedsWithMetadataIndex) {
- ImageSearcherOptions options;
- - options.mutable_base_options()->mutable_model_file()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetV3Searcher));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Searcher));
- options.mutable_embedding_options()->set_l2_normalize(true);
-
- SUPPORT_ASSERT_OK(ImageSearcher::CreateFromOptions(options));
- @@ -129,9 +127,8 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
-
- TEST_F(CreateFromOptionsTest, FailsWithMissingIndex) {
- ImageSearcherOptions options;
- - options.mutable_base_options()->mutable_model_file()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetV3Embedder));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder));
- options.mutable_embedding_options()->set_l2_normalize(true);
-
- StatusOr<std::unique_ptr<ImageSearcher>> image_searcher_or =
- @@ -151,9 +148,8 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingIndex) {
-
- TEST_F(CreateFromOptionsTest, FailsWithQuantization) {
- ImageSearcherOptions options;
- - options.mutable_base_options()->mutable_model_file()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetV3Embedder));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder));
- options.mutable_embedding_options()->set_l2_normalize(true);
- options.mutable_embedding_options()->set_quantize(true);
- options.mutable_search_options()->mutable_index_file()->set_file_name(
- @@ -174,9 +170,8 @@ TEST_F(CreateFromOptionsTest, FailsWithQuantization) {
-
- TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
- ImageSearcherOptions options;
- - options.mutable_base_options()->mutable_model_file()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetV3Embedder));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder));
- options.mutable_embedding_options()->set_l2_normalize(true);
- options.mutable_search_options()->mutable_index_file()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex));
- @@ -197,14 +192,13 @@ TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
- TEST(SearchTest, SucceedsWithStandaloneIndex) {
- // Create Searcher.
- ImageSearcherOptions options;
- - options.mutable_base_options()->mutable_model_file()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetV3Embedder));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder));
- options.mutable_embedding_options()->set_l2_normalize(true);
- options.mutable_search_options()->mutable_index_file()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSearcher> searcher,
- - ImageSearcher::CreateFromOptions(options));
- + ImageSearcher::CreateFromOptions(options));
- // Load image.
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
- std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
- @@ -212,7 +206,7 @@ TEST(SearchTest, SucceedsWithStandaloneIndex) {
-
- // Perform search.
- SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result,
- - searcher->Search(*frame_buffer));
- + searcher->Search(*frame_buffer));
- ImageDataFree(&image);
-
- // Check results.
- @@ -229,12 +223,11 @@ TEST(SearchTest, SucceedsWithStandaloneIndex) {
- TEST(SearchTest, SucceedsWithMetadataIndex) {
- // Create Searcher.
- ImageSearcherOptions options;
- - options.mutable_base_options()->mutable_model_file()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetV3Searcher));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Searcher));
- options.mutable_embedding_options()->set_l2_normalize(true);
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSearcher> searcher,
- - ImageSearcher::CreateFromOptions(options));
- + ImageSearcher::CreateFromOptions(options));
- // Load image.
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
- std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
- @@ -242,7 +235,7 @@ TEST(SearchTest, SucceedsWithMetadataIndex) {
-
- // Perform search.
- SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result,
- - searcher->Search(*frame_buffer));
- + searcher->Search(*frame_buffer));
- ImageDataFree(&image);
-
- // Check results.
- @@ -259,15 +252,14 @@ TEST(SearchTest, SucceedsWithMetadataIndex) {
- TEST(SearchTest, SucceedsWithMaxResults) {
- // Create Searcher.
- ImageSearcherOptions options;
- - options.mutable_base_options()->mutable_model_file()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileNetV3Embedder));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder));
- options.mutable_embedding_options()->set_l2_normalize(true);
- options.mutable_search_options()->mutable_index_file()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex));
- options.mutable_search_options()->set_max_results(2);
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSearcher> searcher,
- - ImageSearcher::CreateFromOptions(options));
- + ImageSearcher::CreateFromOptions(options));
- // Load image.
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
- std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
- @@ -275,7 +267,7 @@ TEST(SearchTest, SucceedsWithMaxResults) {
-
- // Perform search.
- SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result,
- - searcher->Search(*frame_buffer));
- + searcher->Search(*frame_buffer));
- ImageDataFree(&image);
-
- // Check results.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc
- index e32b8e4c27524..8671b68c3b884 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc
- @@ -17,9 +17,9 @@ limitations under the License.
-
- #include <memory>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- #include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/cord.h" // from @com_google_absl
- +#include "absl/strings/cord.h" // from @com_google_absl
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
- #include "tensorflow/lite/kernels/builtin_op_kernels.h"
- @@ -99,8 +99,8 @@ constexpr float kGoldenMaskTolerance = 1e-2;
- constexpr int kGoldenMaskMagnificationFactor = 10;
-
- StatusOr<ImageData> LoadImage(std::string image_name) {
- - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, image_name));
- + return DecodeImageFromFile(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
- }
-
- // Checks that the two provided `Segmentation` protos are equal.
- @@ -141,8 +141,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {};
-
- TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
- ImageSegmenterOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
-
- SUPPORT_ASSERT_OK(ImageSegmenter::CreateFromOptions(
- options, absl::make_unique<DeepLabOpResolver>()));
- @@ -160,8 +160,8 @@ class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
-
- TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
- ImageSegmenterOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
-
- auto image_segmenter_or = ImageSegmenter::CreateFromOptions(
- options, absl::make_unique<DeepLabOpResolverMissingOps>());
- @@ -177,10 +177,10 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
-
- TEST_F(CreateFromOptionsTest, FailsWithTwoModelSources) {
- ImageSegmenterOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- - options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
-
- StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or =
- ImageSegmenter::CreateFromOptions(options);
- @@ -212,8 +212,8 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
-
- TEST_F(CreateFromOptionsTest, FailsWithUnspecifiedOutputType) {
- ImageSegmenterOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- options.set_output_type(ImageSegmenterOptions::UNSPECIFIED);
-
- auto image_segmenter_or = ImageSegmenter::CreateFromOptions(options);
- @@ -230,8 +230,8 @@ TEST_F(CreateFromOptionsTest, FailsWithUnspecifiedOutputType) {
- TEST_F(CreateFromOptionsTest, SucceedsWithNumberOfThreads) {
- ImageSegmenterOptions options;
- options.set_num_threads(4);
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
-
- SUPPORT_ASSERT_OK(ImageSegmenter::CreateFromOptions(options));
- }
- @@ -243,8 +243,8 @@ INSTANTIATE_TEST_SUITE_P(Default, NumThreadsTest, testing::Values(0, -2));
- TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) {
- ImageSegmenterOptions options;
- options.set_num_threads(GetParam());
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
-
- StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or =
- ImageSegmenter::CreateFromOptions(options);
- @@ -263,21 +263,21 @@ TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) {
- TEST(SegmentTest, SucceedsWithCategoryMask) {
- // Load input and build frame buffer.
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
- - LoadImage("segmentation_input_rotation0.jpg"));
- + LoadImage("segmentation_input_rotation0.jpg"));
- std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
- rgb_image.pixel_data,
- FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
- // Load golden mask output.
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask,
- - LoadImage("segmentation_golden_rotation0.png"));
- + LoadImage("segmentation_golden_rotation0.png"));
-
- ImageSegmenterOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> image_segmenter,
- - ImageSegmenter::CreateFromOptions(options));
- + ImageSegmenter::CreateFromOptions(options));
- SUPPORT_ASSERT_OK_AND_ASSIGN(const SegmentationResult result,
- - image_segmenter->Segment(*frame_buffer));
- + image_segmenter->Segment(*frame_buffer));
-
- EXPECT_EQ(result.segmentation_size(), 1);
- const Segmentation& segmentation = result.segmentation(0);
- @@ -301,23 +301,24 @@ TEST(SegmentTest, SucceedsWithCategoryMask) {
-
- TEST(SegmentTest, SucceedsWithOrientation) {
- // Load input and build frame buffer with kRightBottom orientation.
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
- - LoadImage("segmentation_input_rotation90_flop.jpg"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + ImageData rgb_image, LoadImage("segmentation_input_rotation90_flop.jpg"));
- std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
- rgb_image.pixel_data,
- FrameBuffer::Dimension{rgb_image.width, rgb_image.height},
- FrameBuffer::Orientation::kRightBottom);
- // Load golden mask output.
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask,
- - LoadImage("segmentation_golden_rotation90_flop.png"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + ImageData golden_mask,
- + LoadImage("segmentation_golden_rotation90_flop.png"));
-
- ImageSegmenterOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> image_segmenter,
- - ImageSegmenter::CreateFromOptions(options));
- + ImageSegmenter::CreateFromOptions(options));
- SUPPORT_ASSERT_OK_AND_ASSIGN(const SegmentationResult result,
- - image_segmenter->Segment(*frame_buffer));
- + image_segmenter->Segment(*frame_buffer));
-
- EXPECT_EQ(result.segmentation_size(), 1);
- const Segmentation& segmentation = result.segmentation(0);
- @@ -341,21 +342,21 @@ TEST(SegmentTest, SucceedsWithOrientation) {
- TEST(SegmentTest, SucceedsWithBaseOptions) {
- // Load input and build frame buffer.
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
- - LoadImage("segmentation_input_rotation0.jpg"));
- + LoadImage("segmentation_input_rotation0.jpg"));
- std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
- rgb_image.pixel_data,
- FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
- // Load golden mask output.
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask,
- - LoadImage("segmentation_golden_rotation0.png"));
- + LoadImage("segmentation_golden_rotation0.png"));
-
- ImageSegmenterOptions options;
- - options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> image_segmenter,
- - ImageSegmenter::CreateFromOptions(options));
- + ImageSegmenter::CreateFromOptions(options));
- SUPPORT_ASSERT_OK_AND_ASSIGN(const SegmentationResult result,
- - image_segmenter->Segment(*frame_buffer));
- + image_segmenter->Segment(*frame_buffer));
-
- EXPECT_EQ(result.segmentation_size(), 1);
- const Segmentation& segmentation = result.segmentation(0);
- @@ -461,18 +462,18 @@ class PostprocessTest : public tflite_shims::testing::Test {
-
- TEST_F(PostprocessTest, SucceedsWithCategoryMask) {
- ImageSegmenterOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- std::unique_ptr<FrameBuffer> frame_buffer =
- CreateFromRgbaRawBuffer(/*input=*/nullptr, {});
-
- SetUp(options);
- ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_;
- SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor,
- - FillAndGetOutputTensor());
- + FillAndGetOutputTensor());
- SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result,
- - test_image_segmenter_->Postprocess(
- - {output_tensor}, *frame_buffer, /*roi=*/{}));
- + test_image_segmenter_->Postprocess(
- + {output_tensor}, *frame_buffer, /*roi=*/{}));
-
- EXPECT_EQ(result.segmentation_size(), 1);
- const Segmentation& segmentation = result.segmentation(0);
- @@ -487,8 +488,8 @@ TEST_F(PostprocessTest, SucceedsWithCategoryMask) {
-
- TEST_F(PostprocessTest, SucceedsWithCategoryMaskAndOrientation) {
- ImageSegmenterOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- // Frame buffer with kRightBottom orientation.
- std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbaRawBuffer(
- /*input=*/nullptr, {}, FrameBuffer::Orientation::kRightBottom);
- @@ -496,10 +497,10 @@ TEST_F(PostprocessTest, SucceedsWithCategoryMaskAndOrientation) {
- SetUp(options);
- ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_;
- SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor,
- - FillAndGetOutputTensor());
- + FillAndGetOutputTensor());
- SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result,
- - test_image_segmenter_->Postprocess(
- - {output_tensor}, *frame_buffer, /*roi=*/{}));
- + test_image_segmenter_->Postprocess(
- + {output_tensor}, *frame_buffer, /*roi=*/{}));
-
- EXPECT_EQ(result.segmentation_size(), 1);
- const Segmentation& segmentation = result.segmentation(0);
- @@ -515,18 +516,18 @@ TEST_F(PostprocessTest, SucceedsWithCategoryMaskAndOrientation) {
- TEST_F(PostprocessTest, SucceedsWithConfidenceMask) {
- ImageSegmenterOptions options;
- options.set_output_type(ImageSegmenterOptions::CONFIDENCE_MASK);
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- std::unique_ptr<FrameBuffer> frame_buffer =
- CreateFromRgbaRawBuffer(/*input=*/nullptr, {});
-
- SetUp(options);
- ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_;
- SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor,
- - FillAndGetOutputTensor());
- + FillAndGetOutputTensor());
- SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result,
- - test_image_segmenter_->Postprocess(
- - {output_tensor}, *frame_buffer, /*roi=*/{}));
- + test_image_segmenter_->Postprocess(
- + {output_tensor}, *frame_buffer, /*roi=*/{}));
-
- EXPECT_EQ(result.segmentation_size(), 1);
- const Segmentation& segmentation = result.segmentation(0);
- @@ -547,8 +548,8 @@ TEST_F(PostprocessTest, SucceedsWithConfidenceMask) {
- TEST_F(PostprocessTest, SucceedsWithConfidenceMaskAndOrientation) {
- ImageSegmenterOptions options;
- options.set_output_type(ImageSegmenterOptions::CONFIDENCE_MASK);
- - options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- - "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- + options.mutable_model_file_with_metadata()->set_file_name(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- // Frame buffer with kRightBottom orientation.
- std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbaRawBuffer(
- /*input=*/nullptr, {}, FrameBuffer::Orientation::kRightBottom);
- @@ -556,10 +557,10 @@ TEST_F(PostprocessTest, SucceedsWithConfidenceMaskAndOrientation) {
- SetUp(options);
- ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_;
- SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor,
- - FillAndGetOutputTensor());
- + FillAndGetOutputTensor());
- SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result,
- - test_image_segmenter_->Postprocess(
- - {output_tensor}, *frame_buffer, /*roi=*/{}));
- + test_image_segmenter_->Postprocess(
- + {output_tensor}, *frame_buffer, /*roi=*/{}));
-
- EXPECT_EQ(result.segmentation_size(), 1);
- const Segmentation& segmentation = result.segmentation(0);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc
- index a4f35574d7bfe..6c0f395868e20 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc
- @@ -17,9 +17,9 @@ limitations under the License.
-
- #include <memory>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- #include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/cord.h" // from @com_google_absl
- +#include "absl/strings/cord.h" // from @com_google_absl
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
- #include "tensorflow/lite/kernels/builtin_op_kernels.h"
- @@ -103,8 +103,8 @@ constexpr char kEfficientDetWithMetadata[] =
- "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite";
-
- StatusOr<ImageData> LoadImage(std::string image_name) {
- - return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
- - kTestDataDirectory, image_name));
- + return DecodeImageFromFile(
- + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
- }
-
- // Checks that the two provided `DetectionResult` protos are equal, with a
- @@ -153,9 +153,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {};
-
- TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
- ObjectDetectorOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileSsdWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
-
- SUPPORT_ASSERT_OK(ObjectDetector::CreateFromOptions(
- options, absl::make_unique<MobileSsdQuantizedOpResolver>()));
- @@ -186,9 +185,8 @@ class MobileSsdQuantizedOpResolverMissingOps
-
- TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
- ObjectDetectorOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileSsdWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
-
- auto object_detector_or = ObjectDetector::CreateFromOptions(
- options, absl::make_unique<MobileSsdQuantizedOpResolverMissingOps>());
- @@ -203,12 +201,10 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
-
- TEST_F(CreateFromOptionsTest, FailsWithTwoModelSources) {
- ObjectDetectorOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileSsdWithMetadata));
- - options.mutable_base_options()->mutable_model_file()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileSsdWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
-
- StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or =
- ObjectDetector::CreateFromOptions(options);
- @@ -241,9 +237,8 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
-
- TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
- ObjectDetectorOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileSsdWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
- options.set_max_results(0);
-
- StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or =
- @@ -260,9 +255,8 @@ TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
-
- TEST_F(CreateFromOptionsTest, FailsWithCombinedWhitelistAndBlacklist) {
- ObjectDetectorOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileSsdWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
- options.add_class_name_whitelist("foo");
- options.add_class_name_blacklist("bar");
-
- @@ -281,9 +275,8 @@ TEST_F(CreateFromOptionsTest, FailsWithCombinedWhitelistAndBlacklist) {
- TEST_F(CreateFromOptionsTest, SucceedsWithNumberOfThreads) {
- ObjectDetectorOptions options;
- options.set_num_threads(4);
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileSsdWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
-
- SUPPORT_ASSERT_OK(ObjectDetector::CreateFromOptions(options));
- }
- @@ -295,9 +288,8 @@ INSTANTIATE_TEST_SUITE_P(Default, NumThreadsTest, testing::Values(0, -2));
- TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) {
- ObjectDetectorOptions options;
- options.set_num_threads(GetParam());
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileSsdWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
-
- StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or =
- ObjectDetector::CreateFromOptions(options);
- @@ -315,51 +307,52 @@ TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) {
- class DetectTest : public tflite_shims::testing::Test {};
-
- TEST_F(DetectTest, Succeeds) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("cats_and_dogs.jpg"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
- + LoadImage("cats_and_dogs.jpg"));
- std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
- rgb_image.pixel_data,
- FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
-
- ObjectDetectorOptions options;
- options.set_max_results(4);
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileSsdWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
- - ObjectDetector::CreateFromOptions(options));
- + ObjectDetector::CreateFromOptions(options));
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result,
- - object_detector->Detect(*frame_buffer));
- + object_detector->Detect(*frame_buffer));
- ImageDataFree(&rgb_image);
- ExpectApproximatelyEqual(
- result, ParseTextProtoOrDie<DetectionResult>(kExpectResults));
- }
-
- TEST_F(DetectTest, SucceedswithBaseOptions) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("cats_and_dogs.jpg"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
- + LoadImage("cats_and_dogs.jpg"));
- std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
- rgb_image.pixel_data,
- FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
-
- ObjectDetectorOptions options;
- options.set_max_results(4);
- - options.mutable_base_options()->mutable_model_file()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileSsdWithMetadata));
- + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
- - ObjectDetector::CreateFromOptions(options));
- + ObjectDetector::CreateFromOptions(options));
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result,
- - object_detector->Detect(*frame_buffer));
- + object_detector->Detect(*frame_buffer));
- ImageDataFree(&rgb_image);
- ExpectApproximatelyEqual(
- result, ParseTextProtoOrDie<DetectionResult>(kExpectResults));
- }
-
- TEST_F(DetectTest, SucceedswithScoreCalibrations) {
- - SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("cats_and_dogs.jpg"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
- + LoadImage("cats_and_dogs.jpg"));
- std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
- rgb_image.pixel_data,
- FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
- @@ -371,10 +364,10 @@ TEST_F(DetectTest, SucceedswithScoreCalibrations) {
- kMobileSsdWithMetadataDummyScoreCalibration));
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
- - ObjectDetector::CreateFromOptions(options));
- + ObjectDetector::CreateFromOptions(options));
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result,
- - object_detector->Detect(*frame_buffer));
- + object_detector->Detect(*frame_buffer));
- ImageDataFree(&rgb_image);
- ExpectApproximatelyEqual(
- result, ParseTextProtoOrDie<DetectionResult>(kExpectResults));
- @@ -482,20 +475,21 @@ class PostprocessTest : public tflite_shims::testing::Test {
-
- TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) {
- ObjectDetectorOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileSsdWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
- options.set_score_threshold(0.5);
-
- SetUp(options);
- ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors,
- - FillAndGetOutputTensors());
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + const std::vector<const TfLiteTensor*> output_tensors,
- + FillAndGetOutputTensors());
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(DetectionResult result,
- - test_object_detector_->Postprocess(
- - output_tensors, *dummy_frame_buffer_, /*roi=*/{}));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + DetectionResult result,
- + test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_,
- + /*roi=*/{}));
-
- ExpectApproximatelyEqual(
- result,
- @@ -517,16 +511,16 @@ TEST_F(PostprocessTest, SucceedsWithFrameBufferOrientation) {
- FrameBuffer::Orientation::kBottomRight);
-
- ObjectDetectorOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileSsdWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
- options.set_score_threshold(0.5);
-
- SetUp(options);
- ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors,
- - FillAndGetOutputTensors());
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + const std::vector<const TfLiteTensor*> output_tensors,
- + FillAndGetOutputTensors());
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(
- DetectionResult result,
- @@ -549,20 +543,21 @@ TEST_F(PostprocessTest, SucceedsWithFrameBufferOrientation) {
-
- TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) {
- ObjectDetectorOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileSsdWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
- options.set_max_results(1);
-
- SetUp(options);
- ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors,
- - FillAndGetOutputTensors());
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + const std::vector<const TfLiteTensor*> output_tensors,
- + FillAndGetOutputTensors());
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(DetectionResult result,
- - test_object_detector_->Postprocess(
- - output_tensors, *dummy_frame_buffer_, /*roi=*/{}));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + DetectionResult result,
- + test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_,
- + /*roi=*/{}));
-
- ExpectApproximatelyEqual(
- result,
- @@ -576,21 +571,22 @@ TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) {
-
- TEST_F(PostprocessTest, SucceedsWithWhitelistOption) {
- ObjectDetectorOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileSsdWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
- options.add_class_name_whitelist("car");
- options.add_class_name_whitelist("motorcycle");
-
- SetUp(options);
- ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors,
- - FillAndGetOutputTensors());
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + const std::vector<const TfLiteTensor*> output_tensors,
- + FillAndGetOutputTensors());
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(DetectionResult result,
- - test_object_detector_->Postprocess(
- - output_tensors, *dummy_frame_buffer_, /*roi=*/{}));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + DetectionResult result,
- + test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_,
- + /*roi=*/{}));
-
- ExpectApproximatelyEqual(
- result,
- @@ -608,9 +604,8 @@ TEST_F(PostprocessTest, SucceedsWithWhitelistOption) {
-
- TEST_F(PostprocessTest, SucceedsWithBlacklistOption) {
- ObjectDetectorOptions options;
- - options.mutable_model_file_with_metadata()->set_file_name(
- - JoinPath("./" /*test src dir*/, kTestDataDirectory,
- - kMobileSsdWithMetadata));
- + options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- + "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
- options.add_class_name_blacklist("car");
- // Setting score threshold to discard the 7 padded-with-zeros results.
- options.set_score_threshold(0.1);
- @@ -618,12 +613,14 @@ TEST_F(PostprocessTest, SucceedsWithBlacklistOption) {
- SetUp(options);
- ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors,
- - FillAndGetOutputTensors());
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + const std::vector<const TfLiteTensor*> output_tensors,
- + FillAndGetOutputTensors());
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(DetectionResult result,
- - test_object_detector_->Postprocess(
- - output_tensors, *dummy_frame_buffer_, /*roi=*/{}));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + DetectionResult result,
- + test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_,
- + /*roi=*/{}));
-
- ExpectApproximatelyEqual(
- result,
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc
- index 7937dbafb090b..c16815cb38061 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc
- @@ -21,13 +21,16 @@ namespace tflite {
- namespace task {
-
- std::string JoinPath(absl::string_view path1, absl::string_view path2) {
- - if (path1.empty()) return std::string(path2);
- - if (path2.empty()) return std::string(path1);
- + if (path1.empty())
- + return std::string(path2);
- + if (path2.empty())
- + return std::string(path1);
- if (path1.back() == '/') {
- if (path2.front() == '/')
- return absl::StrCat(path1, absl::ClippedSubstr(path2, 1));
- } else {
- - if (path2.front() != '/') return absl::StrCat(path1, "/", path2);
- + if (path2.front() != '/')
- + return absl::StrCat(path1, "/", path2);
- }
- return absl::StrCat(path1, path2);
- }
- @@ -44,14 +47,16 @@ std::string JoinPathImpl(bool honor_abs,
- // This size calculation is worst-case: it assumes one extra "/" for every
- // path other than the first.
- size_t total_size = paths.size() - 1;
- - for (const absl::string_view path : paths) total_size += path.size();
- + for (const absl::string_view path : paths)
- + total_size += path.size();
- result.resize(total_size);
-
- auto begin = result.begin();
- auto out = begin;
- bool trailing_slash = false;
- for (absl::string_view path : paths) {
- - if (path.empty()) continue;
- + if (path.empty())
- + continue;
- if (path.front() == '/') {
- if (honor_abs) {
- out = begin; // wipe out whatever we've built up so far.
- @@ -59,7 +64,8 @@ std::string JoinPathImpl(bool honor_abs,
- path.remove_prefix(1);
- }
- } else {
- - if (!trailing_slash && out != begin) *out++ = '/';
- + if (!trailing_slash && out != begin)
- + *out++ = '/';
- }
- const size_t this_size = path.size();
- memcpy(&*out, path.data(), this_size);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h
- index db72bc5d5ae98..1d730d5a6d981 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h
- @@ -33,8 +33,10 @@ std::string JoinPathImpl(bool honor_abs,
- std::string JoinPath(absl::string_view path1, absl::string_view path2);
-
- template <typename... T>
- -inline std::string JoinPath(absl::string_view path1, absl::string_view path2,
- - absl::string_view path3, const T&... args) {
- +inline std::string JoinPath(absl::string_view path1,
- + absl::string_view path2,
- + absl::string_view path3,
- + const T&... args) {
- return internal::JoinPathImpl(false, {path1, path2, path3, args...});
- }
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc
- index 6a050668edcbe..53c88310dde43 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc
- @@ -31,7 +31,8 @@ FlatHashMapBackedWordpiece::FlatHashMapBackedWordpiece(
- }
-
- tensorflow::text::LookupStatus FlatHashMapBackedWordpiece::Contains(
- - absl::string_view key, bool* value) const {
- + absl::string_view key,
- + bool* value) const {
- *value = index_map_.contains(key);
- return tensorflow::text::LookupStatus();
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h
- index aec178daf3cc5..1de54fa8f651c 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h
- @@ -103,7 +103,8 @@ class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer {
-
- // Initialize the tokenizer from buffer and size of vocab and tokenizer
- // configs.
- - BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size,
- + BertTokenizer(const char* vocab_buffer_data,
- + size_t vocab_buffer_size,
- const BertTokenizerOptions& options = {})
- : BertTokenizer(
- utils::LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size),
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc
- index 151161777863f..249bc2d1b6bc2 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc
- @@ -31,9 +31,14 @@ using ::tflite::support::utils::StringListToVector;
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeLoadResource( // NOLINT
- - JNIEnv* env, jobject thiz, jobject vocab_list, jint max_bytes_per_token,
- - jint max_chars_per_sub_token, jstring jsuffix_indicator,
- - jboolean use_unknown_token, jstring junknown_token,
- + JNIEnv* env,
- + jobject thiz,
- + jobject vocab_list,
- + jint max_bytes_per_token,
- + jint max_chars_per_sub_token,
- + jstring jsuffix_indicator,
- + jboolean use_unknown_token,
- + jstring junknown_token,
- jboolean split_unknown_chars) {
- // Convert java.util.List<String> into std::vector<string>
- std::vector<std::string> vocab = StringListToVector(env, vocab_list);
- @@ -66,20 +71,28 @@ Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeLoadResourc
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeUnloadResource( // NOLINT
- - JNIEnv* env, jobject thiz, jlong handle) {
- + JNIEnv* env,
- + jobject thiz,
- + jlong handle) {
- delete reinterpret_cast<BertTokenizer*>(handle);
- return 0;
- }
-
- extern "C" JNIEXPORT jobjectArray JNICALL
- Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeTokenize(
- - JNIEnv* env, jobject thiz, jlong handle, jstring jtext) {
- + JNIEnv* env,
- + jobject thiz,
- + jlong handle,
- + jstring jtext) {
- return nativeTokenize(env, handle, jtext);
- }
-
- extern "C" JNIEXPORT jintArray JNICALL
- Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeConvertTokensToIds( // NOLINT
- - JNIEnv* env, jobject thiz, jlong handle, jobjectArray jtokens) {
- + JNIEnv* env,
- + jobject thiz,
- + jlong handle,
- + jobjectArray jtokens) {
- return nativeConvertTokensToIds(env, handle, jtokens);
- }
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
- index 832f9df42f824..ded6fbd13ea4a 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
- @@ -17,7 +17,7 @@ limitations under the License.
-
- #include <iostream>
-
- -#include "absl/strings/str_cat.h" // from @com_google_absl
- +#include "absl/strings/str_cat.h" // from @com_google_absl
- #include "absl/strings/substitute.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/utils/common_utils.h"
- namespace tflite {
- @@ -70,7 +70,7 @@ TokenizerResult RegexTokenizer::Tokenize(const std::string& input) {
- re2::StringPiece extracted_delim_token;
- while (RE2::FindAndConsume(&leftover, delim_re_, &extracted_delim_token)) {
- re2::StringPiece token(last_end.data(),
- - extracted_delim_token.data() - last_end.data());
- + extracted_delim_token.data() - last_end.data());
- bool has_non_empty_token = token.length() > 0;
-
- last_end = leftover;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc
- index 6ecfff0d2baa1..8ca14c52eb262 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc
- @@ -20,7 +20,7 @@ limitations under the License.
- #include <utility>
- #include <vector>
-
- -#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- #include "absl/strings/str_split.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h"
- #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h"
- @@ -34,7 +34,9 @@ using ::tflite::support::utils::GetMappedFileBuffer;
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeLoadResource( // NOLINT
- - JNIEnv* env, jobject obj, jobject model_buffer) {
- + JNIEnv* env,
- + jobject obj,
- + jobject model_buffer) {
- auto model = GetMappedFileBuffer(env, model_buffer);
- auto handle =
- absl::make_unique<SentencePieceTokenizer>(model.data(), model.size());
- @@ -43,20 +45,28 @@ Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeLo
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeUnloadResource( // NOLINT
- - JNIEnv* env, jobject obj, jlong handle) {
- + JNIEnv* env,
- + jobject obj,
- + jlong handle) {
- delete reinterpret_cast<SentencePieceTokenizer*>(handle);
- return 0;
- }
-
- extern "C" JNIEXPORT jobjectArray JNICALL
- Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeTokenize( // NOLINT
- - JNIEnv* env, jobject thiz, jlong handle, jstring jtext) {
- + JNIEnv* env,
- + jobject thiz,
- + jlong handle,
- + jstring jtext) {
- return nativeTokenize(env, handle, jtext);
- }
-
- extern "C" JNIEXPORT jintArray JNICALL
- Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeConvertTokensToIds( // NOLINT
- - JNIEnv* env, jobject thiz, jlong handle, jobjectArray jtokens) {
- + JNIEnv* env,
- + jobject thiz,
- + jlong handle,
- + jobjectArray jtokens) {
- return nativeConvertTokensToIds(env, handle, jtokens);
- }
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc
- index a72523be5984e..4e32bc5581a48 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc
- @@ -54,7 +54,8 @@ jobjectArray nativeTokenize(JNIEnv* env, jlong handle, jstring jtext) {
- return result;
- }
-
- -jintArray nativeConvertTokensToIds(JNIEnv* env, jlong handle,
- +jintArray nativeConvertTokensToIds(JNIEnv* env,
- + jlong handle,
- jobjectArray jtokens) {
- if (handle == 0) {
- env->ThrowNew(env->FindClass(kIllegalStateException),
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h
- index 33677d305a853..fd76f3aa553e4 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h
- @@ -25,7 +25,8 @@ namespace support {
-
- jobjectArray nativeTokenize(JNIEnv* env, jlong handle, jstring jtext);
-
- -jintArray nativeConvertTokensToIds(JNIEnv* env, jlong handle,
- +jintArray nativeConvertTokensToIds(JNIEnv* env,
- + jlong handle,
- jobjectArray jtokens);
-
- } // namespace support
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
- index 28f0137f54278..32957d155dce6 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
- @@ -73,9 +73,9 @@ StatusOr<std::unique_ptr<Tokenizer>> CreateTokenizerFromProcessUnit(
- }
- case ProcessUnitOptions_SentencePieceTokenizerOptions: {
- return CreateStatusWithPayload(
- - absl::StatusCode::kInvalidArgument,
- - "Chromium does not support sentencepiece tokenization",
- - TfLiteSupportStatus::kMetadataInvalidTokenizerError);
- + absl::StatusCode::kInvalidArgument,
- + "Chromium does not support sentencepiece tokenization",
- + TfLiteSupportStatus::kMetadataInvalidTokenizerError);
- }
- case ProcessUnitOptions_RegexTokenizerOptions: {
- const tflite::RegexTokenizerOptions* options =
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h
- index 2e50a79963f82..696c5d4e27db7 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h
- @@ -26,7 +26,6 @@ namespace support {
- namespace text {
- namespace tokenizer {
-
- -
- // Create a Tokenizer from model metadata by extracting
- tflite::support::StatusOr<std::unique_ptr<Tokenizer>>
- CreateTokenizerFromProcessUnit(
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc
- index 84cc0ef6ae52e..3ea6b147fcdd6 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc
- @@ -83,7 +83,8 @@ absl::node_hash_map<std::string, int> LoadVocabAndIndexFromFile(
- }
-
- absl::node_hash_map<std::string, int> LoadVocabAndIndexFromBuffer(
- - const char* vocab_buffer_data, const size_t vocab_buffer_size) {
- + const char* vocab_buffer_data,
- + const size_t vocab_buffer_size) {
- membuf sbuf(const_cast<char*>(vocab_buffer_data),
- const_cast<char*>(vocab_buffer_data + vocab_buffer_size));
- absl::node_hash_map<std::string, int> vocab_index_map;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h
- index 6921d2f5ac01b..275c4932f8ec0 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h
- @@ -41,7 +41,8 @@ absl::node_hash_map<std::string, int> LoadVocabAndIndexFromFile(
- // Read a vocab buffer with one vocabulary and its corresponding index on each
- // line separated by space, create a map of <vocab, index>.
- absl::node_hash_map<std::string, int> LoadVocabAndIndexFromBuffer(
- - const char* vocab_buffer_data, const size_t vocab_buffer_size);
- + const char* vocab_buffer_data,
- + const size_t vocab_buffer_size);
- } // namespace utils
- } // namespace support
- } // namespace tflite
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc
- index bf9e93f9aa24a..35ce822951ad8 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc
- @@ -18,8 +18,8 @@ limitations under the License.
- #include <dlfcn.h>
- #include <string.h>
-
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow/lite/core/shims/c/experimental/acceleration/configuration/delegate_plugin.h"
- #include "tensorflow/lite/core/shims/cc/experimental/acceleration/configuration/delegate_registry.h"
- @@ -168,7 +168,8 @@ void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...) {
- va_end(args);
- }
-
- -void ThrowExceptionWithMessage(JNIEnv* env, const char* clazz,
- +void ThrowExceptionWithMessage(JNIEnv* env,
- + const char* clazz,
- const char* message) {
- jclass e_class = env->FindClass(clazz);
- if (strcmp(clazz, kAssertionError) == 0) {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h
- index 6d15bb43e75b3..f92f838bb9a71 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h
- @@ -22,7 +22,7 @@ limitations under the License.
- #include <string>
- #include <vector>
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/configuration_proto_inc.h"
- #include "tensorflow_lite_support/cc/port/statusor.h"
- @@ -59,7 +59,9 @@ T CheckNotNull(JNIEnv* env, T&& t) {
- // interable before adding it to the ArrayList.
- template <typename Iterator>
- jobject ConvertVectorToArrayList(
- - JNIEnv* env, const Iterator& begin, const Iterator& end,
- + JNIEnv* env,
- + const Iterator& begin,
- + const Iterator& end,
- std::function<jobject(typename std::iterator_traits<Iterator>::value_type)>
- converter) {
- jclass array_list_class = env->FindClass("java/util/ArrayList");
- @@ -94,7 +96,8 @@ jbyteArray CreateByteArray(JNIEnv* env, const jbyte* data, int num_bytes);
-
- void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...);
-
- -void ThrowExceptionWithMessage(JNIEnv* env, const char* clazz,
- +void ThrowExceptionWithMessage(JNIEnv* env,
- + const char* clazz,
- const char* message);
-
- const char* GetExceptionClassNameForStatusCode(absl::StatusCode status_code);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc
- index eb94cb7020475..bb8f1f4d40655 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc
- @@ -63,7 +63,8 @@ using details_android_java::TensorInfo;
- // Using ctor and dtor to simulate an enter/exit schema like `with` in Python.
- class AsBlock {
- public:
- - AsBlock(CodeWriter* code_writer, const std::string& before,
- + AsBlock(CodeWriter* code_writer,
- + const std::string& before,
- bool trailing_blank_line = false)
- : code_writer_(code_writer), trailing_blank_line_(trailing_blank_line) {
- code_writer_->AppendNoNewLine(before);
- @@ -105,7 +106,9 @@ std::string GetModelVersionedName(const ModelMetadata* metadata) {
- }
-
- TensorInfo CreateTensorInfo(const TensorMetadata* metadata,
- - const std::string& name, bool is_input, int index,
- + const std::string& name,
- + bool is_input,
- + int index,
- ErrorReporter* err) {
- TensorInfo tensor_info;
- std::string tensor_identifier = is_input ? "input" : "output";
- @@ -273,7 +276,8 @@ bool IsImageUsed(const ModelInfo& model) {
-
- // The following functions generates the wrapper Java code for a model.
-
- -bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model,
- +bool GenerateWrapperFileContent(CodeWriter* code_writer,
- + const ModelInfo& model,
- ErrorReporter* err) {
- code_writer->Append("// Generated by TFLite Support.");
- code_writer->Append("package {{PACKAGE}};");
- @@ -291,7 +295,8 @@ bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model,
- return true;
- }
-
- -bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
- +bool GenerateWrapperImports(CodeWriter* code_writer,
- + const ModelInfo& model,
- ErrorReporter* err) {
- const std::string support_pkg = "org.tensorflow.lite.support.";
- std::vector<std::string> imports{
- @@ -336,7 +341,8 @@ bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
- return true;
- }
-
- -bool GenerateWrapperClass(CodeWriter* code_writer, const ModelInfo& model,
- +bool GenerateWrapperClass(CodeWriter* code_writer,
- + const ModelInfo& model,
- ErrorReporter* err) {
- code_writer->SetTokenValue("MODEL_VERSIONED_NAME",
- model.model_versioned_name);
- @@ -373,7 +379,8 @@ private static final String MODEL_NAME = "{{MODEL_PATH}}";)");
- return true;
- }
-
- -bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
- +bool GenerateWrapperOutputs(CodeWriter* code_writer,
- + const ModelInfo& model,
- ErrorReporter* err) {
- code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */");
- auto class_block = AsBlock(code_writer, "public static class Outputs");
- @@ -459,7 +466,8 @@ bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
- return true;
- }
-
- -bool GenerateWrapperMetadata(CodeWriter* code_writer, const ModelInfo& model,
- +bool GenerateWrapperMetadata(CodeWriter* code_writer,
- + const ModelInfo& model,
- ErrorReporter* err) {
- code_writer->Append(
- "/** Metadata accessors of {@link {{MODEL_CLASS_NAME}}} */");
- @@ -605,7 +613,8 @@ public List<String> get{{NAME_U}}Labels() {
- return true;
- }
-
- -bool GenerateWrapperAPI(CodeWriter* code_writer, const ModelInfo& model,
- +bool GenerateWrapperAPI(CodeWriter* code_writer,
- + const ModelInfo& model,
- ErrorReporter* err) {
- code_writer->Append(R"(public Metadata getMetadata() {
- return metadata;
- @@ -980,8 +989,10 @@ AndroidJavaGenerator::AndroidJavaGenerator(const std::string& module_root)
- : CodeGenerator(), module_root_(module_root) {}
-
- GenerationResult AndroidJavaGenerator::Generate(
- - const Model* model, const std::string& package_name,
- - const std::string& model_class_name, const std::string& model_asset_path) {
- + const Model* model,
- + const std::string& package_name,
- + const std::string& model_class_name,
- + const std::string& model_asset_path) {
- GenerationResult result;
- if (model == nullptr) {
- err_.Error(
- @@ -1006,8 +1017,10 @@ GenerationResult AndroidJavaGenerator::Generate(
- }
-
- GenerationResult AndroidJavaGenerator::Generate(
- - const char* model_storage, const std::string& package_name,
- - const std::string& model_class_name, const std::string& model_asset_path) {
- + const char* model_storage,
- + const std::string& package_name,
- + const std::string& model_class_name,
- + const std::string& model_asset_path) {
- const Model* model = GetModel(model_storage);
- return Generate(model, package_name, model_class_name, model_asset_path);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.h b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.h
- index 634ccf69f6c1a..1ea8bb2182a67 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.h
- @@ -20,10 +20,10 @@ limitations under the License.
- #include <string>
- #include <vector>
-
- +#include "tensorflow/lite/schema/schema_generated.h"
- #include "tensorflow_lite_support/codegen/code_generator.h"
- #include "tensorflow_lite_support/codegen/utils.h"
- #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
- -#include "tensorflow/lite/schema/schema_generated.h"
-
- namespace tflite {
- namespace support {
- @@ -90,7 +90,8 @@ class AndroidJavaGenerator : public CodeGenerator {
- /// as "ImageClassifier", "MobileNetV2" or "MyModel".
- /// - model_asset_path: The relevant path to the model file in the asset.
- // TODO(b/141225157): Automatically generate model_class_name.
- - GenerationResult Generate(const Model* model, const std::string& package_name,
- + GenerationResult Generate(const Model* model,
- + const std::string& package_name,
- const std::string& model_class_name,
- const std::string& model_asset_path);
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.cc
- index 1337708d4ac66..b6ec55cbc5e8b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.cc
- @@ -144,7 +144,8 @@ std::string CodeGenerator::NameTensor(const TensorMetadata& tensor,
- }
-
- void CodeGenerator::ResolveConflictedInputAndOutputNames(
- - std::vector<std::string>* inputs, std::vector<std::string>* outputs) {
- + std::vector<std::string>* inputs,
- + std::vector<std::string>* outputs) {
- std::unordered_set<std::string> io_conflict;
- auto& input_names = *inputs;
- auto& output_names = *outputs;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.h b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.h
- index b557773ddcc7a..fe67327986bd7 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.h
- @@ -70,7 +70,8 @@ class CodeGenerator {
- static std::string NameTensor(const TensorMetadata& tensor,
- const std::string& default_name);
- static void ResolveConflictedInputAndOutputNames(
- - std::vector<std::string>* input, std::vector<std::string>* output);
- + std::vector<std::string>* input,
- + std::vector<std::string>* output);
- };
-
- } // namespace codegen
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator_test.cc
- index 5e9d64a0d8f98..ccc87668ed3cb 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator_test.cc
- @@ -36,7 +36,8 @@ class CodeGeneratorTest : public ::testing::Test {
- return CodeGenerator::ConvertToValidName(name);
- }
- static void ResolveConflictedInputAndOutputNames(
- - std::vector<std::string>* input, std::vector<std::string>* output) {
- + std::vector<std::string>* input,
- + std::vector<std::string>* output) {
- CodeGenerator::ResolveConflictedInputAndOutputNames(input, output);
- }
- };
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/metadata_helper.h b/third_party/tflite_support/src/tensorflow_lite_support/codegen/metadata_helper.h
- index 8e3dc6abaed66..193dfb2fb23f3 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/metadata_helper.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/metadata_helper.h
- @@ -18,9 +18,9 @@ limitations under the License.
-
- #include <string>
-
- +#include "tensorflow/lite/schema/schema_generated.h"
- #include "tensorflow_lite_support/codegen/utils.h"
- #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
- -#include "tensorflow/lite/schema/schema_generated.h"
-
- namespace tflite {
- namespace support {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen_lib.cc
- index 6b2cd5ea9a778..a9da2403afc4f 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen_lib.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen_lib.cc
- @@ -29,11 +29,10 @@ using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>;
-
- PYBIND11_MODULE(_pywrap_codegen, m) {
- pybind11::class_<AndroidJavaGenerator>(m, "AndroidJavaGenerator")
- - .def(pybind11::init<const std::string &>())
- - .def("generate",
- - overload_cast_<const char *, const std::string &,
- - const std::string &, const std::string &>()(
- - &AndroidJavaGenerator::Generate))
- + .def(pybind11::init<const std::string&>())
- + .def("generate", overload_cast_<const char*, const std::string&,
- + const std::string&, const std::string&>()(
- + &AndroidJavaGenerator::Generate))
- .def("get_error_message", &AndroidJavaGenerator::GetErrorMessage);
- pybind11::class_<GenerationResult>(m, "GenerationResult")
- .def(pybind11::init<>())
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/utils.cc
- index c75fc5fae631d..e89d09629dda1 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/utils.cc
- @@ -32,7 +32,8 @@ int ErrorReporter::Error(const char* format, ...) {
- return Report("[ERROR] ", format, args);
- }
-
- -int ErrorReporter::Report(const char* prefix, const char* format,
- +int ErrorReporter::Report(const char* prefix,
- + const char* format,
- va_list args) {
- char buf[1024];
- int formatted = vsnprintf(buf, sizeof(buf), format, args);
- @@ -69,9 +70,13 @@ void CodeWriter::SetIndentString(const std::string& indent_str) {
- indent_str_ = indent_str;
- }
-
- -void CodeWriter::Indent() { indent_++; }
- +void CodeWriter::Indent() {
- + indent_++;
- +}
-
- -void CodeWriter::Outdent() { indent_--; }
- +void CodeWriter::Outdent() {
- + indent_--;
- +}
-
- std::string CodeWriter::GenerateIndent() const {
- std::string res;
- @@ -82,7 +87,9 @@ std::string CodeWriter::GenerateIndent() const {
- return res;
- }
-
- -void CodeWriter::Append(const std::string& text) { AppendInternal(text, true); }
- +void CodeWriter::Append(const std::string& text) {
- + AppendInternal(text, true);
- +}
-
- void CodeWriter::AppendNoNewLine(const std::string& text) {
- AppendInternal(text, false);
- @@ -144,15 +151,21 @@ void CodeWriter::AppendInternal(const std::string& text, bool newline) {
- }
- }
-
- -void CodeWriter::NewLine() { Append(""); }
- +void CodeWriter::NewLine() {
- + Append("");
- +}
-
- void CodeWriter::Backspace(int n) {
- buffer_.resize(buffer_.size() > n ? buffer_.size() - n : 0);
- }
-
- -std::string CodeWriter::ToString() const { return buffer_; }
- +std::string CodeWriter::ToString() const {
- + return buffer_;
- +}
-
- -bool CodeWriter::IsStreamEmpty() const { return buffer_.empty(); }
- +bool CodeWriter::IsStreamEmpty() const {
- + return buffer_.empty();
- +}
-
- void CodeWriter::Clear() {
- buffer_.clear();
- @@ -181,11 +194,14 @@ std::string SnakeCaseToCamelCase(const std::string& s) {
- }
-
- std::string JoinPath(const std::string& a, const std::string& b) {
- - if (a.empty()) return b;
- + if (a.empty())
- + return b;
- std::string a_fixed = a;
- - if (!a_fixed.empty() && a_fixed.back() == '/') a_fixed.pop_back();
- + if (!a_fixed.empty() && a_fixed.back() == '/')
- + a_fixed.pop_back();
- std::string b_fixed = b;
- - if (!b_fixed.empty() && b_fixed.front() == '/') b_fixed.erase(0, 1);
- + if (!b_fixed.empty() && b_fixed.front() == '/')
- + b_fixed.erase(0, 1);
- return a_fixed + "/" + b_fixed;
- }
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams.cc
- index 3831c63ca17cc..f55ffb907f133 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams.cc
- @@ -66,7 +66,9 @@ struct NgramsAttributes {
- string_separator(m["string_separator"].ToString()) {}
- };
-
- -inline bool OutputIsTensor(TfLiteNode* node) { return NumOutputs(node) == 1; }
- +inline bool OutputIsTensor(TfLiteNode* node) {
- + return NumOutputs(node) == 1;
- +}
- inline int NumRowSplits(TfLiteNode* node) {
- return NumInputs(node) - kRowSplitsStart;
- }
- @@ -176,7 +178,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- std::vector<StringRef> tokens;
- for (int j = input_row_splits[i]; j < input_row_splits[i + 1]; ++j) {
- tokens.emplace_back(GetString(input_values, j));
- - if (tokens.size() < attributes.width) continue;
- + if (tokens.size() < attributes.width)
- + continue;
- tokens.erase(tokens.begin(),
- tokens.begin() + tokens.size() - attributes.width);
- buffer.AddJoinedString(tokens, separator);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc
- index b87fcac328623..dc21f37beb3bf 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc
- @@ -15,8 +15,8 @@ limitations under the License.
-
- #include "tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.h"
-
- -#include "tensorflow_lite_support/custom_ops/kernel/ngrams.h"
- #include "tensorflow/lite/mutable_op_resolver.h"
- +#include "tensorflow_lite_support/custom_ops/kernel/ngrams.h"
-
- namespace tflite {
- namespace ops {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc
- index 91ef47af6fd0f..4a5e671fa0987 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc
- @@ -40,7 +40,8 @@ using ::testing::ElementsAreArray;
- class NgramsModel : public SingleOpModel {
- public:
- // Constructor for testing the op with a tf.Tensor
- - NgramsModel(int width, const std::string& string_separator,
- + NgramsModel(int width,
- + const std::string& string_separator,
- const std::vector<std::string>& input_values,
- const std::vector<int>& input_shape) {
- input_values_ = AddInput(TensorType_STRING);
- @@ -56,7 +57,8 @@ class NgramsModel : public SingleOpModel {
- // Constructor for the op with a tf.RaggedTensor
- // Note: This interface uses row_lengths, as they're closer to the
- // dimensions in a TensorShape, but internally everything is row_splits.
- - NgramsModel(int width, const std::string& string_separator,
- + NgramsModel(int width,
- + const std::string& string_separator,
- const std::vector<std::string>& input_values,
- const std::vector<std::vector<int64_t>> nested_row_lengths) {
- std::vector<std::vector<int>> input_shapes;
- @@ -203,8 +205,7 @@ TEST(NgramsTest, TensorMultidimensionalInputWidthTwo) {
- TEST(NgramsTest, RaggedTensorSingleSequenceWidthTwo) {
- std::vector<std::vector<int64_t>> nested_row_lengths;
- nested_row_lengths.push_back({4});
- - NgramsModel m(2, " ", {"this", "is", "a", "test"},
- - nested_row_lengths);
- + NgramsModel m(2, " ", {"this", "is", "a", "test"}, nested_row_lengths);
- EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3));
- EXPECT_THAT(m.ExtractValuesTensorVector(),
- ElementsAre("this is", "is a", "a test"));
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h
- index ade3c5c178920..811be781d27fe 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h
- @@ -20,6 +20,6 @@ limitations under the License.
- // C-function that is called from the Python Wrapper.
-
- extern "C" void TFLite_RaggedTensorToTensorRegisterer(
- - tflite::MutableOpResolver *resolver);
- + tflite::MutableOpResolver* resolver);
-
- #endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_RAGGED_PY_TFLITE_REGISTERER_H_
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc
- index a35a6db9ad48f..9fc73dd0f9778 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc
- @@ -71,9 +71,12 @@ TfLiteStatus EvalT(TfLiteContext* context, TfLiteNode* node) {
- // nrows (number of output rows) is the size of the non-broadcast inputs,
- // or 1 if all inputs are scalars.
- std::vector<int> in_sizes;
- - if (!broadcast_starts) in_sizes.push_back(input_starts.dims->data[0]);
- - if (!broadcast_limits) in_sizes.push_back(input_limits.dims->data[0]);
- - if (!broadcast_deltas) in_sizes.push_back(input_deltas.dims->data[0]);
- + if (!broadcast_starts)
- + in_sizes.push_back(input_starts.dims->data[0]);
- + if (!broadcast_limits)
- + in_sizes.push_back(input_limits.dims->data[0]);
- + if (!broadcast_deltas)
- + in_sizes.push_back(input_deltas.dims->data[0]);
- if (std::adjacent_find(std::begin(in_sizes), std::end(in_sizes),
- std::not_equal_to<>()) != std::end(in_sizes)) {
- context->ReportError(
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc
- index 75a460538aaaa..fc838bee4d98b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc
- @@ -39,7 +39,8 @@ class RaggedRangeOpModel : public SingleOpModel {
- public:
- static TensorType GetType();
-
- - RaggedRangeOpModel(const std::vector<T>& start, const std::vector<T>& limits,
- + RaggedRangeOpModel(const std::vector<T>& start,
- + const std::vector<T>& limits,
- const std::vector<T>& deltas) {
- const TensorType value_type = GetType();
- std::vector<std::vector<int>> shapes;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc
- index 09ac76c71b26c..ff5c14b8e5e08 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc
- @@ -140,8 +140,10 @@ RuntimeShape TensorShapeFromTensor(const TfLiteTensor& tensor) {
- }
-
- const TfLiteTensor* GetRowPartitionTensor(
- - const ConversionAttributes& conversion_attributes, TfLiteContext* context,
- - TfLiteNode* node, int dimension) {
- + const ConversionAttributes& conversion_attributes,
- + TfLiteContext* context,
- + TfLiteNode* node,
- + int dimension) {
- if (conversion_attributes.partition_types.front() ==
- tensorflow::RowPartitionType::FIRST_DIM_SIZE) {
- return &context->tensors[node->inputs->data[kFirstPartitionInputIndex + 1 +
- @@ -211,7 +213,9 @@ int GetMaxWidthRowSplit(const TfLiteTensor* tensor) {
- }
-
- int GetMaxWidth(const ConversionAttributes& conversion_attributes,
- - TfLiteContext* context, TfLiteNode* node, int dimension) {
- + TfLiteContext* context,
- + TfLiteNode* node,
- + int dimension) {
- const TfLiteTensor* tensor = GetRowPartitionTensor(
- conversion_attributes, context, node, dimension - 1);
- switch (conversion_attributes.GetRowPartitionTypeByDimension(dimension - 1)) {
- @@ -226,7 +230,8 @@ int GetMaxWidth(const ConversionAttributes& conversion_attributes,
- }
-
- RuntimeShape CombineRaggedTensorToTensorShapes(
- - int ragged_rank, const RuntimeShape& output_shape,
- + int ragged_rank,
- + const RuntimeShape& output_shape,
- const RuntimeShape& value_shape) {
- // TODO(mgubin): No checks, see
- // third_party/tensorflow/core/ops/ragged_to_dense_util.cc
- @@ -247,9 +252,13 @@ RuntimeShape CombineRaggedTensorToTensorShapes(
- }
-
- RuntimeShape CalculateOutputSize(
- - const ConversionAttributes& conversion_attributes, TfLiteContext* context,
- - TfLiteNode* node, int first_dimension, int ragged_rank,
- - const TfLiteTensor& values, const TfLiteTensor& default_value,
- + const ConversionAttributes& conversion_attributes,
- + TfLiteContext* context,
- + TfLiteNode* node,
- + int first_dimension,
- + int ragged_rank,
- + const TfLiteTensor& values,
- + const TfLiteTensor& default_value,
- const TfLiteTensor& output_shape) {
- RuntimeShape values_shape(values.dims->size, values.dims->data);
- RuntimeShape default_value_shape(default_value.dims->size,
- @@ -331,7 +340,8 @@ void CalculateFirstParentOutputIndex(int first_dimension,
- void CalculateOutputIndexValueRowID(const TfLiteTensor& value_rowids,
- const std::vector<int>& parent_output_index,
- int output_index_multiplier,
- - int output_size, std::vector<int>* result) {
- + int output_size,
- + std::vector<int>* result) {
- const RuntimeShape tensor_shape(value_rowids.dims->size,
- value_rowids.dims->data);
- const int index_size = tensor_shape.FlatSize();
- @@ -380,7 +390,8 @@ void CalculateOutputIndexValueRowID(const TfLiteTensor& value_rowids,
-
- void CalculateOutputIndexRowSplit(const TfLiteTensor& row_split,
- const std::vector<int>& parent_output_index,
- - int output_index_multiplier, int output_size,
- + int output_index_multiplier,
- + int output_size,
- std::vector<int>* result) {
- const RuntimeShape row_split_shape(row_split.dims->size,
- row_split.dims->data);
- @@ -421,10 +432,14 @@ void CalculateOutputIndexRowSplit(const TfLiteTensor& row_split,
- }
-
- TfLiteStatus CalculateOutputIndex(
- - const ConversionAttributes& conversion_attributes, TfLiteContext* context,
- - TfLiteNode* node, int dimension,
- - const std::vector<int>& parent_output_index, int output_index_multiplier,
- - int output_size, std::vector<int>* result) {
- + const ConversionAttributes& conversion_attributes,
- + TfLiteContext* context,
- + TfLiteNode* node,
- + int dimension,
- + const std::vector<int>& parent_output_index,
- + int output_index_multiplier,
- + int output_size,
- + std::vector<int>* result) {
- const TfLiteTensor* row_partition_tensor =
- GetRowPartitionTensor(conversion_attributes, context, node, dimension);
- auto partition_type =
- @@ -447,7 +462,8 @@ TfLiteStatus CalculateOutputIndex(
- }
-
- template <typename VALUE_TYPE>
- -void SetOutputT(TfLiteContext* context, int ragged_rank,
- +void SetOutputT(TfLiteContext* context,
- + int ragged_rank,
- const std::vector<int>& output_index,
- const TfLiteTensor& values_tensor,
- const TfLiteTensor& default_value_tensor,
- @@ -522,7 +538,8 @@ void SetOutputT(TfLiteContext* context, int ragged_rank,
- }
- }
-
- -void SetOutput(TfLiteContext* context, int ragged_rank,
- +void SetOutput(TfLiteContext* context,
- + int ragged_rank,
- const std::vector<int>& output_index,
- const TfLiteTensor& values_tensor,
- const TfLiteTensor& default_value_tensor,
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc
- index b1cde57c47c68..2f7a2a95b8478 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc
- @@ -82,7 +82,8 @@ class RaggedTensorToTensorOpModel : public SingleOpModel {
- std::vector<int32> GetOutputInt() { return ExtractVector<int32>(output_); }
-
- void InvokeFloat(const std::vector<int>& shape,
- - const std::vector<float>& values, float default_value,
- + const std::vector<float>& values,
- + float default_value,
- const std::vector<std::vector<int>>& partition_values) {
- PopulateTensor(input_shape_, shape);
- PopulateTensor(input_values_, values);
- @@ -93,7 +94,8 @@ class RaggedTensorToTensorOpModel : public SingleOpModel {
- SingleOpModel::Invoke();
- }
- void InvokeInt(const std::vector<int>& shape,
- - const std::vector<int32>& values, int32 default_value,
- + const std::vector<int32>& values,
- + int32 default_value,
- const std::vector<std::vector<int>>& partition_values) {
- PopulateTensor(input_shape_, shape);
- PopulateTensor(input_values_, values);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc
- index 4e2b87de37327..47ba9fdfebcae 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc
- @@ -15,8 +15,8 @@ limitations under the License.
-
- #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h"
-
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/str_replace.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/str_replace.h" // from @com_google_absl
- #include "src/sentencepiece_model.pb.h" // from @com_google_sentencepiece
- #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config_generated.h"
- #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h"
- @@ -48,7 +48,8 @@ DecodePrecompiledCharsmap(
- }
-
- tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
- - const std::string& model_config_str, int encoding_offset) {
- + const std::string& model_config_str,
- + int encoding_offset) {
- ::sentencepiece::ModelProto model_config;
- if (!model_config.ParseFromString(model_config_str)) {
- return absl::InvalidArgumentError(
- @@ -128,7 +129,8 @@ tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
-
- tflite::support::StatusOr<std::string>
- ConvertSentencepieceModelToFlatBufferForDecoder(
- - const std::string& model_config_str, int encoding_offset) {
- + const std::string& model_config_str,
- + int encoding_offset) {
- ::sentencepiece::ModelProto model_config;
- if (!model_config.ParseFromString(model_config_str)) {
- return absl::InvalidArgumentError(
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h
- index 5687b6287d140..03b3596820886 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h
- @@ -27,13 +27,15 @@ namespace sentencepiece {
- // Converts Sentencepiece configuration to flatbuffer format.
- // encoding_offset is used by some encoders that combine different encodings.
- tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
- - const std::string& model_config_str, int encoding_offset = 0);
- + const std::string& model_config_str,
- + int encoding_offset = 0);
-
- // Converts Sentencepiece configuration to flatbuffer format for encoder.
- // encoding_offset is used by some encoders that combine different encodings.
- tflite::support::StatusOr<std::string>
- ConvertSentencepieceModelToFlatBufferForDecoder(
- - const std::string& model_config_str, int encoding_offset = 0);
- + const std::string& model_config_str,
- + int encoding_offset = 0);
-
- // The functions that are provided for the Python wrapper.
- std::string ConvertSentencepieceModel(const std::string& model_string);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc
- index 8e130ef73b9b6..94161c2ac4c4e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc
- @@ -19,9 +19,9 @@ limitations under the License.
-
- #include <gmock/gmock.h>
- #include <gtest/gtest.h>
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/strings/str_format.h" // from @com_google_absl
- -#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/strings/str_format.h" // from @com_google_absl
- +#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece
- #include "src/sentencepiece_processor.h" // from @com_google_sentencepiece
- #include "tensorflow/core/platform/env.h"
- #include "tensorflow_lite_support/cc/test/test_utils.h"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc
- index 45fde32237c65..4148f8e96627a 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc
- @@ -31,7 +31,8 @@ const char kSpaceSymbol[] = "\xe2\x96\x81";
-
- template <typename processing_callback>
- std::tuple<std::string, std::vector<int>> process_string(
- - const std::string& input, const std::vector<int>& offsets,
- + const std::string& input,
- + const std::vector<int>& offsets,
- const processing_callback& pc) {
- std::string result_string;
- result_string.reserve(input.size());
- @@ -78,7 +79,9 @@ std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data,
- }
-
- std::tuple<int, utils::string_view> find_replacement(
- - const char* data, int len, const DoubleArrayTrie& dat,
- + const char* data,
- + int len,
- + const DoubleArrayTrie& dat,
- const flatbuffers::Vector<int8_t>& replacements) {
- const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len));
- if (!max_match.empty()) {
- @@ -94,7 +97,8 @@ std::tuple<int, utils::string_view> find_replacement(
- } // namespace
-
- std::tuple<std::string, std::vector<int>> NormalizeString(
- - const std::string& in_string, const EncoderConfig& config) {
- + const std::string& in_string,
- + const EncoderConfig& config) {
- std::vector<int> output_offsets;
- std::string result = in_string;
- output_offsets.reserve(in_string.length());
- @@ -145,8 +149,10 @@ std::tuple<std::string, std::vector<int>> NormalizeString(
-
- EncoderResult EncodeNormalizedString(const std::string& str,
- const std::vector<int>& offsets,
- - const EncoderConfig& config, bool add_bos,
- - bool add_eos, bool reverse) {
- + const EncoderConfig& config,
- + bool add_bos,
- + bool add_eos,
- + bool reverse) {
- const DoubleArrayTrie piece_matcher(config.pieces()->nodes());
- const flatbuffers::Vector<float>* piece_scores = config.pieces_scores();
- const int unknown_code = config.unknown_code();
- @@ -219,8 +225,11 @@ EncoderResult EncodeNormalizedString(const std::string& str,
- return result;
- }
-
- -EncoderResult EncodeString(const std::string& string, const void* config_buffer,
- - bool add_bos, bool add_eos, bool reverse) {
- +EncoderResult EncodeString(const std::string& string,
- + const void* config_buffer,
- + bool add_bos,
- + bool add_eos,
- + bool reverse) {
- // Get the config from the buffer.
- const EncoderConfig* config = GetEncoderConfig(config_buffer);
- if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h
- index 44d6e88f2531c..b89154cbfa396 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h
- @@ -37,12 +37,16 @@ struct EncoderResult {
- std::vector<int> offsets;
- };
- std::tuple<std::string, std::vector<int>> NormalizeString(
- - const std::string& in_string, const EncoderConfig& config);
- + const std::string& in_string,
- + const EncoderConfig& config);
-
- // Encodes one string and returns ids and offsets. Takes the configuration as a
- // type-erased buffer.
- -EncoderResult EncodeString(const std::string& string, const void* config_buffer,
- - bool add_bos, bool add_eos, bool reverse);
- +EncoderResult EncodeString(const std::string& string,
- + const void* config_buffer,
- + bool add_bos,
- + bool add_eos,
- + bool reverse);
-
- } // namespace sentencepiece
- } // namespace custom
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc
- index e2787c785e8c4..dd956a22b26c1 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc
- @@ -19,10 +19,10 @@ limitations under the License.
-
- #include <gmock/gmock.h>
- #include <gtest/gtest.h>
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/str_format.h" // from @com_google_absl
- -#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/str_format.h" // from @com_google_absl
- +#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece
- #include "src/sentencepiece_processor.h" // from @com_google_sentencepiece
- #include "tensorflow/core/platform/env.h"
- #include "tensorflow_lite_support/cc/test/test_utils.h"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h
- index deb4e4ee08dc2..3efcfefc6438d 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h
- @@ -20,6 +20,6 @@ limitations under the License.
- // C-function that is called from the Python Wrapper.
-
- extern "C" void TFLite_SentencepieceTokenizerRegisterer(
- - tflite::MutableOpResolver *resolver);
- + tflite::MutableOpResolver* resolver);
-
- #endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_PY_TFLITE_REGISTERER_H_
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc
- index 54b34e4e33196..f5be376b45e12 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc
- @@ -35,7 +35,8 @@ namespace detokenizer {
-
- constexpr int kOutputValuesInd = 0;
- // Initializes text encoder object from serialized parameters.
- -void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
- +void* Initialize(TfLiteContext* /*context*/,
- + const char* /*buffer*/,
- size_t /*length*/) {
- return nullptr;
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc
- index 41fc5aa28bf30..68f8e64492394 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc
- @@ -16,16 +16,16 @@ limitations under the License.
- #include <iterator>
- #include <vector>
-
- -#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
- -#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h"
- #include "tensorflow/core/framework/op.h"
- #include "tensorflow/core/framework/op_kernel.h"
- #include "tensorflow/core/framework/shape_inference.h"
- #include "tensorflow/core/framework/tensor.h"
- #include "tensorflow/core/protobuf/error_codes.pb.h"
- +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
- +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h"
-
- namespace tensorflow {
- -namespace ops{
- +namespace ops {
-
- // copied from third_party/tensorflow_text/core/ops/sentencepiece_ops.cc
- REGISTER_OP("TFSentencepieceTokenizeOp")
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc
- index 8309a6a2616fd..edb0160b508a3 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc
- @@ -16,8 +16,6 @@ limitations under the License.
- /**
- * Sentencepiece tflite tokenizer implementation.
- */
- -#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
- -#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h"
- #include "flatbuffers/flexbuffers.h" // from @flatbuffers
- #include "tensorflow/lite/c/common.h"
- #include "tensorflow/lite/context.h"
- @@ -25,6 +23,8 @@ limitations under the License.
- #include "tensorflow/lite/kernels/kernel_util.h"
- #include "tensorflow/lite/model.h"
- #include "tensorflow/lite/string_util.h"
- +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
- +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h"
-
- namespace tflite {
- namespace ops {
- @@ -47,7 +47,8 @@ TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
- } // namespace
-
- // Initializes text encoder object from serialized parameters.
- -void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
- +void* Initialize(TfLiteContext* /*context*/,
- + const char* /*buffer*/,
- size_t /*length*/) {
- return nullptr;
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc
- index dad2f0004be06..8096a5008bd12 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc
- @@ -19,10 +19,10 @@ limitations under the License.
- #include <utility>
- #include <vector>
-
- +#include "libutf/utf.h"
- #include "tensorflow/lite/context.h"
- #include "tensorflow/lite/kernels/kernel_util.h"
- #include "tensorflow/lite/string_util.h"
- -#include "libutf/utf.h"
-
- constexpr int kInput = 0;
- constexpr int kOutputValues = 0;
- @@ -49,7 +49,7 @@ inline bool OutputIsPaddedTensor(TfLiteNode* node) {
- }
-
- inline int charntorune(Rune* r, const char* s, int n) {
- - const int bytes_read = chartorune(r, const_cast<char *>(s));
- + const int bytes_read = chartorune(r, const_cast<char*>(s));
- if (bytes_read > n) {
- *r = Runeerror;
- return 0;
- @@ -66,7 +66,8 @@ std::vector<std::pair<const char*, int>> Tokenize(StringRef str) {
- while (n > 0) {
- Rune r;
- int c = charntorune(&r, p, n);
- - if (r == Runeerror) break;
- + if (r == Runeerror)
- + break;
-
- if (isspacerune(r)) {
- if (start != nullptr) {
- @@ -91,7 +92,8 @@ std::vector<std::pair<const char*, int>> Tokenize(StringRef str) {
-
- TfLiteStatus WritePaddedOutput(
- const std::vector<std::vector<std::pair<const char*, int>>>& list_of_tokens,
- - const TfLiteTensor* input, TfLiteTensor* output_values) {
- + const TfLiteTensor* input,
- + TfLiteTensor* output_values) {
- TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(input) + 1);
- for (int i = 0; i < NumDimensions(input); ++i) {
- output_shape->data[i] = SizeOfDimension(input, i);
- @@ -118,7 +120,8 @@ TfLiteStatus WritePaddedOutput(
-
- TfLiteStatus WriteRaggedOutput(
- const std::vector<std::vector<std::pair<const char*, int>>>& list_of_tokens,
- - const TfLiteTensor* input, TfLiteTensor* output_values,
- + const TfLiteTensor* input,
- + TfLiteTensor* output_values,
- std::vector<TfLiteTensor*> nested_row_splits) {
- // The outer dimensions of the ragged tensor are all non-ragged.
- for (int i = 0; i < nested_row_splits.size() - 1; ++i) {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc
- index 534fbef4aff2d..6166bc149bc00 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc
- @@ -15,8 +15,8 @@ limitations under the License.
-
- #include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.h"
-
- -#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h"
- #include "tensorflow/lite/mutable_op_resolver.h"
- +#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h"
-
- namespace tflite {
- namespace ops {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc
- index 7447870046f48..904673a95b799 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc
- @@ -28,18 +28,26 @@ limitations under the License.
- #include "absl/flags/parse.h" // from @com_google_absl
- #include "tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h"
-
- -ABSL_FLAG(std::string, model_path, "",
- +ABSL_FLAG(std::string,
- + model_path,
- + "",
- "Absolute path to the '.tflite' audio classification model.");
- -ABSL_FLAG(std::string, audio_wav_path, "",
- +ABSL_FLAG(std::string,
- + audio_wav_path,
- + "",
- "Absolute path to the 16-bit PCM WAV file to classify. The WAV "
- "file must be monochannel and has a sampling rate matches the model "
- "expected sampling rate (as in the Metadata). If the WAV file is "
- "longer than what the model requires, only the beginning section is "
- "used for inference.");
- -ABSL_FLAG(float, score_threshold, 0.001f,
- +ABSL_FLAG(float,
- + score_threshold,
- + 0.001f,
- "Apply a filter on the results. Only display classes with score "
- "higher than the threshold.");
- -ABSL_FLAG(bool, use_coral, false,
- +ABSL_FLAG(bool,
- + use_coral,
- + false,
- "If true, inference will be delegated to a connected Coral Edge TPU "
- "device.");
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc
- index 36d6633d902e3..a843501ec3d75 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc
- @@ -19,7 +19,7 @@ limitations under the License.
- #include <string>
- #include <vector>
-
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- #include "tensorflow_lite_support/cc/port/statusor.h"
- @@ -34,7 +34,8 @@ namespace task {
- namespace audio {
-
- tflite::support::StatusOr<AudioBuffer> LoadAudioBufferFromFile(
- - const std::string& wav_file, int buffer_size,
- + const std::string& wav_file,
- + int buffer_size,
- std::vector<float>* wav_data) {
- std::string contents = ReadFile(wav_file);
-
- @@ -55,7 +56,8 @@ tflite::support::StatusOr<AudioBuffer> LoadAudioBufferFromFile(
- }
-
- tflite::support::StatusOr<ClassificationResult> Classify(
- - const std::string& model_path, const std::string& wav_file,
- + const std::string& model_path,
- + const std::string& wav_file,
- bool use_coral) {
- AudioClassifierOptions options;
- options.mutable_base_options()->mutable_model_file()->set_file_name(
- @@ -97,7 +99,8 @@ void Display(const ClassificationResult& result, float score_threshold) {
- std::cout << absl::StrFormat("\nHead[%d]: %s\n", i, head.head_name());
- for (int j = 0; j < head.classes_size(); j++) {
- const auto& category = head.classes(j);
- - if (category.score() < score_threshold) continue;
- + if (category.score() < score_threshold)
- + continue;
- std::cout << absl::StrFormat("\tcategory[%s]: %.5f\t",
- category.class_name(), category.score());
- if (!category.display_name().empty()) {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h
- index 6d23078ba3e19..13b2d7792e025 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h
- @@ -28,7 +28,8 @@ namespace audio {
- // than what the model requires, only the beginning section is used for
- // inference.
- tflite::support::StatusOr<ClassificationResult> Classify(
- - const std::string& model_path, const std::string& wav_file,
- + const std::string& model_path,
- + const std::string& wav_file,
- bool use_coral = false);
-
- // Prints the output classification result in the standard output. It only
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc
- index 02eed2332b2e4..5203200808d60 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc
- @@ -15,18 +15,22 @@ limitations under the License.
- #include <iostream>
- #include <limits>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/flags/parse.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/flags/parse.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/core/category.h"
- #include "tensorflow_lite_support/cc/task/text/bert_nl_classifier.h"
-
- -ABSL_FLAG(std::string, model_path, "",
- +ABSL_FLAG(std::string,
- + model_path,
- + "",
- "Absolute path to the '.tflite' bert classification model.");
- ABSL_FLAG(std::string, text, "", "Text to classify.");
- -ABSL_FLAG(bool, use_coral, false,
- +ABSL_FLAG(bool,
- + use_coral,
- + false,
- "If true, inference will be delegated to a connected Coral Edge TPU "
- "device.");
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc
- index 4eaa2bbbdd9f5..f2577cfad54c2 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc
- @@ -15,19 +15,25 @@ limitations under the License.
- #include <iostream>
- #include <limits>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/flags/parse.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/flags/parse.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/text/bert_question_answerer.h"
-
- -ABSL_FLAG(std::string, model_path, "",
- +ABSL_FLAG(std::string,
- + model_path,
- + "",
- "Absolute path to the '.tflite' bert question answerer model.");
- ABSL_FLAG(std::string, question, "", "Question to ask.");
- -ABSL_FLAG(std::string, context, "",
- +ABSL_FLAG(std::string,
- + context,
- + "",
- "Context the asked question is based upon.");
- -ABSL_FLAG(bool, use_coral, false,
- +ABSL_FLAG(bool,
- + use_coral,
- + false,
- "If true, inference will be delegated to a connected Coral Edge TPU "
- "device.");
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc
- index 49f233ce1e74c..613744ffdb20b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc
- @@ -15,18 +15,22 @@ limitations under the License.
- #include <iostream>
- #include <limits>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/flags/parse.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/flags/parse.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/core/category.h"
- #include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
-
- -ABSL_FLAG(std::string, model_path, "",
- +ABSL_FLAG(std::string,
- + model_path,
- + "",
- "Absolute path to the '.tflite' classification model.");
- ABSL_FLAG(std::string, text, "", "Text to classify.");
- -ABSL_FLAG(bool, use_coral, false,
- +ABSL_FLAG(bool,
- + use_coral,
- + false,
- "If true, inference will be delegated to a connected Coral Edge TPU "
- "device.");
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_embedder_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_embedder_demo.cc
- index 875b5f4a771bd..eca8a002d3293 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_embedder_demo.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_embedder_demo.cc
- @@ -24,9 +24,9 @@ limitations under the License.
- #include <iostream>
- #include <memory>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/flags/parse.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/flags/parse.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/configuration_proto_inc.h"
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- @@ -36,19 +36,29 @@ limitations under the License.
- #include "tensorflow_lite_support/cc/task/text/text_embedder.h"
- #include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"
-
- -ABSL_FLAG(std::string, model_path, "",
- +ABSL_FLAG(std::string,
- + model_path,
- + "",
- "Absolute path to the '.tflite' text embedder model.");
- -ABSL_FLAG(std::string, first_sentence, "",
- +ABSL_FLAG(std::string,
- + first_sentence,
- + "",
- "First sentence, whose feature vector will be extracted and compared "
- "to the second sentence using cosine similarity.");
- -ABSL_FLAG(std::string, second_sentence, "",
- +ABSL_FLAG(std::string,
- + second_sentence,
- + "",
- "Second sentence, whose feature vector will be extracted and "
- "compared to the first sentence using cosine similarity.");
- -ABSL_FLAG(bool, l2_normalize, false,
- +ABSL_FLAG(bool,
- + l2_normalize,
- + false,
- "If true, the raw feature vectors returned by the image embedder "
- "will be normalized with L2-norm. Generally only needed if the model "
- "doesn't already contain a L2_NORMALIZATION TFLite Op.");
- -ABSL_FLAG(bool, use_coral, false,
- +ABSL_FLAG(bool,
- + use_coral,
- + false,
- "If true, inference will be delegated to a connected Coral Edge TPU "
- "device.");
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_searcher_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_searcher_demo.cc
- index 5ea9b7e63b50e..0299428964797 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_searcher_demo.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_searcher_demo.cc
- @@ -24,9 +24,9 @@ limitations under the License.
- #include <iostream>
- #include <memory>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/flags/parse.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/flags/parse.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/configuration_proto_inc.h"
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- @@ -39,21 +39,33 @@ limitations under the License.
- #include "tensorflow_lite_support/cc/task/text/text_searcher.h"
- #include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"
-
- -ABSL_FLAG(std::string, model_path, "",
- +ABSL_FLAG(std::string,
- + model_path,
- + "",
- "Absolute path to the '.tflite' text embedder model.");
- -ABSL_FLAG(std::string, index_path, "",
- +ABSL_FLAG(std::string,
- + index_path,
- + "",
- "Absolute path to the index to search into. Mandatory only if the "
- "index is not attached to the output tensor metadata of the embedder "
- "model as an AssociatedFile with type SCANN_INDEX_FILE.");
- -ABSL_FLAG(std::string, input_sentence, "",
- +ABSL_FLAG(std::string,
- + input_sentence,
- + "",
- "Input sentence whose nearest-neighbors to search for in the index.");
- -ABSL_FLAG(int32, max_results, 5,
- +ABSL_FLAG(int32,
- + max_results,
- + 5,
- "Maximum number of nearest-neghbors to display.");
- -ABSL_FLAG(bool, l2_normalize, false,
- +ABSL_FLAG(bool,
- + l2_normalize,
- + false,
- "If true, the raw feature vectors returned by the image embedder "
- "will be normalized with L2-norm. Generally only needed if the model "
- "doesn't already contain a L2_NORMALIZATION TFLite Op.");
- -ABSL_FLAG(bool, use_coral, false,
- +ABSL_FLAG(bool,
- + use_coral,
- + false,
- "If true, inference will be delegated to a connected Coral Edge TPU "
- "device.");
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc
- index 076a60a2330af..f7621a5a8a1b4 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc
- @@ -14,9 +14,9 @@ limitations under the License.
- ==============================================================================*/
-
- // Demostration the usage of UniversalSentenceEncoderQA.
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/flags/parse.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/flags/parse.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_split.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h"
- #include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"
- @@ -29,12 +29,17 @@ using tflite::task::text::RetrievalOutput;
- using tflite::task::text::UniversalSentenceEncoderQA;
- } // namespace
-
- -ABSL_FLAG(std::string, model_path, "",
- +ABSL_FLAG(std::string,
- + model_path,
- + "",
- "Absolute path to the '.tflite' UniversalSentenceEncoderQA model.");
- -ABSL_FLAG(std::string, question, "How are you feeling today?",
- +ABSL_FLAG(std::string,
- + question,
- + "How are you feeling today?",
- "Question to ask.");
- ABSL_FLAG(
- - std::string, answers,
- + std::string,
- + answers,
- "I'm not feeling very well.:Paris is the capital of France.:He looks good.",
- "Candidate answers seperated by `:`.");
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc
- index f29bd2de9c535..0904920faa7dd 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc
- @@ -22,9 +22,9 @@ limitations under the License.
-
- #include <iostream>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/flags/parse.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/flags/parse.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
- @@ -36,29 +36,43 @@ limitations under the License.
- #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
- #include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
-
- -ABSL_FLAG(std::string, model_path, "",
- +ABSL_FLAG(std::string,
- + model_path,
- + "",
- "Absolute path to the '.tflite' image classifier model.");
- -ABSL_FLAG(std::string, image_path, "",
- +ABSL_FLAG(std::string,
- + image_path,
- + "",
- "Absolute path to the image to classify. The image must be RGB or "
- "RGBA (grayscale is not supported). The image EXIF orientation "
- "flag, if any, is NOT taken into account.");
- -ABSL_FLAG(int32, max_results, 5,
- +ABSL_FLAG(int32,
- + max_results,
- + 5,
- "Maximum number of classification results to display.");
- -ABSL_FLAG(float, score_threshold, 0,
- +ABSL_FLAG(float,
- + score_threshold,
- + 0,
- "Classification results with a confidence score below this value are "
- "rejected. If >= 0, overrides the score threshold(s) provided in the "
- "TFLite Model Metadata. Ignored otherwise.");
- ABSL_FLAG(
- - std::vector<std::string>, class_name_whitelist, {},
- + std::vector<std::string>,
- + class_name_whitelist,
- + {},
- "Comma-separated list of class names that acts as a whitelist. If "
- "non-empty, classification results whose 'class_name' is not in this list "
- "are filtered out. Mutually exclusive with 'class_name_blacklist'.");
- ABSL_FLAG(
- - std::vector<std::string>, class_name_blacklist, {},
- + std::vector<std::string>,
- + class_name_blacklist,
- + {},
- "Comma-separated list of class names that acts as a blacklist. If "
- "non-empty, classification results whose 'class_name' is in this list "
- "are filtered out. Mutually exclusive with 'class_name_whitelist'.");
- -ABSL_FLAG(bool, use_coral, false,
- +ABSL_FLAG(bool,
- + use_coral,
- + false,
- "If true, inference will be delegated to a connected Coral Edge TPU "
- "device.");
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc
- index 50d615a486751..f8b1796bc3865 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc
- @@ -26,9 +26,9 @@ limitations under the License.
-
- #include <iostream>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/flags/parse.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/flags/parse.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
- @@ -39,28 +39,40 @@ limitations under the License.
- #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
- #include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
-
- -ABSL_FLAG(std::string, model_path, "",
- +ABSL_FLAG(std::string,
- + model_path,
- + "",
- "Absolute path to the '.tflite' image embedder model.");
- -ABSL_FLAG(std::string, first_image_path, "",
- +ABSL_FLAG(std::string,
- + first_image_path,
- + "",
- "Absolute path to the first image, whose feature vector will be "
- "extracted and compared to the second image using cosine similarity. "
- "The image must be RGB or RGBA (grayscale is not supported). The "
- "image EXIF orientation flag, if any, is NOT taken into account.");
- -ABSL_FLAG(std::string, second_image_path, "",
- +ABSL_FLAG(std::string,
- + second_image_path,
- + "",
- "Absolute path to the second image, whose feature vector will be "
- "extracted and compared to the first image using cosine similarity. "
- "The image must be RGB or RGBA (grayscale is not supported). The "
- "image EXIF orientation flag, if any, is NOT taken into account.");
- -ABSL_FLAG(bool, l2_normalize, false,
- +ABSL_FLAG(bool,
- + l2_normalize,
- + false,
- "If true, the raw feature vectors returned by the image embedder "
- "will be normalized with L2-norm. Generally only needed if the model "
- "doesn't already contain a L2_NORMALIZATION TFLite Op.");
- ABSL_FLAG(
- - bool, quantize, false,
- + bool,
- + quantize,
- + false,
- "If true, the raw feature vectors returned by the image embedder will "
- "be quantized to 8 bit integers (uniform quantization) via post-processing "
- "before cosine similarity is computed.");
- -ABSL_FLAG(bool, use_coral, false,
- +ABSL_FLAG(bool,
- + use_coral,
- + false,
- "If true, inference will be delegated to a connected Coral Edge TPU "
- "device.");
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_searcher_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_searcher_demo.cc
- index b661447614bc7..e4074f76dba5b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_searcher_demo.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_searcher_demo.cc
- @@ -25,9 +25,9 @@ limitations under the License.
- #include <iostream>
- #include <memory>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/flags/parse.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/flags/parse.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- #include "tensorflow_lite_support/cc/port/statusor.h"
- @@ -42,23 +42,35 @@ limitations under the License.
- #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
- #include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
-
- -ABSL_FLAG(std::string, model_path, "",
- +ABSL_FLAG(std::string,
- + model_path,
- + "",
- "Absolute path to the '.tflite' image embedder model.");
- -ABSL_FLAG(std::string, index_path, "",
- +ABSL_FLAG(std::string,
- + index_path,
- + "",
- "Absolute path to the index to search into. Mandatory only if the "
- "index is not attached to the output tensor metadata of the embedder "
- "model as an AssociatedFile with type SCANN_INDEX_FILE.");
- -ABSL_FLAG(std::string, image_path, "",
- +ABSL_FLAG(std::string,
- + image_path,
- + "",
- "Absolute path to the image to search. The image must be RGB or "
- "RGBA (grayscale is not supported). The image EXIF orientation "
- "flag, if any, is NOT taken into account.");
- -ABSL_FLAG(int32, max_results, 5,
- +ABSL_FLAG(int32,
- + max_results,
- + 5,
- "Maximum number of nearest-neighbor results to display.");
- -ABSL_FLAG(bool, l2_normalize, false,
- +ABSL_FLAG(bool,
- + l2_normalize,
- + false,
- "If true, the raw feature vectors returned by the image embedder "
- "will be normalized with L2-norm. Generally only needed if the model "
- "doesn't already contain a L2_NORMALIZATION TFLite Op.");
- -ABSL_FLAG(bool, use_coral, false,
- +ABSL_FLAG(bool,
- + use_coral,
- + false,
- "If true, inference will be delegated to a connected Coral Edge TPU "
- "device.");
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc
- index 5a566ecbcf921..fdc787288fa06 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc
- @@ -23,10 +23,10 @@ limitations under the License.
-
- #include <iostream>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/flags/parse.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/match.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/flags/parse.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/match.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
- @@ -37,16 +37,24 @@ limitations under the License.
- #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
- #include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
-
- -ABSL_FLAG(std::string, model_path, "",
- +ABSL_FLAG(std::string,
- + model_path,
- + "",
- "Absolute path to the '.tflite' image segmenter model.");
- -ABSL_FLAG(std::string, image_path, "",
- +ABSL_FLAG(std::string,
- + image_path,
- + "",
- "Absolute path to the image to segment. The image must be RGB or "
- "RGBA (grayscale is not supported). The image EXIF orientation "
- "flag, if any, is NOT taken into account.");
- -ABSL_FLAG(std::string, output_mask_png, "",
- +ABSL_FLAG(std::string,
- + output_mask_png,
- + "",
- "Absolute path to the output category mask (confidence masks outputs "
- "are not supported by this tool). Must have a '.png' extension.");
- -ABSL_FLAG(bool, use_coral, false,
- +ABSL_FLAG(bool,
- + use_coral,
- + false,
- "If true, inference will be delegated to a connected Coral Edge TPU "
- "device.");
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc
- index 20f7403207c2e..fd000fccf2f29 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc
- @@ -24,10 +24,10 @@ limitations under the License.
- #include <iostream>
- #include <limits>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/flags/parse.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/match.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/flags/parse.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/match.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
- @@ -40,32 +40,48 @@ limitations under the License.
- #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
- #include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
-
- -ABSL_FLAG(std::string, model_path, "",
- +ABSL_FLAG(std::string,
- + model_path,
- + "",
- "Absolute path to the '.tflite' object detector model.");
- -ABSL_FLAG(std::string, image_path, "",
- +ABSL_FLAG(std::string,
- + image_path,
- + "",
- "Absolute path to the image to run detection on. The image must be "
- "RGB or RGBA (grayscale is not supported). The image EXIF "
- "orientation flag, if any, is NOT taken into account.");
- -ABSL_FLAG(std::string, output_png, "",
- +ABSL_FLAG(std::string,
- + output_png,
- + "",
- "Absolute path to a file where to draw the detection results on top "
- "of the input image. Must have a '.png' extension.");
- -ABSL_FLAG(int32, max_results, 5,
- +ABSL_FLAG(int32,
- + max_results,
- + 5,
- "Maximum number of detection results to display.");
- ABSL_FLAG(
- - float, score_threshold, std::numeric_limits<float>::lowest(),
- + float,
- + score_threshold,
- + std::numeric_limits<float>::lowest(),
- "Detection results with a confidence score below this value are "
- "rejected. If specified, overrides the score threshold(s) provided in the "
- "TFLite Model Metadata. Ignored otherwise.");
- ABSL_FLAG(
- - std::vector<std::string>, class_name_whitelist, {},
- + std::vector<std::string>,
- + class_name_whitelist,
- + {},
- "Comma-separated list of class names that acts as a whitelist. If "
- "non-empty, detections results whose 'class_name' is not in this list "
- "are filtered out. Mutually exclusive with 'class_name_blacklist'.");
- -ABSL_FLAG(std::vector<std::string>, class_name_blacklist, {},
- +ABSL_FLAG(std::vector<std::string>,
- + class_name_blacklist,
- + {},
- "Comma-separated list of class names that acts as a blacklist. If "
- "non-empty, detections results whose 'class_name' is in this list "
- "are filtered out. Mutually exclusive with 'class_name_whitelist'.");
- -ABSL_FLAG(bool, use_coral, false,
- +ABSL_FLAG(bool,
- + use_coral,
- + false,
- "If true, inference will be delegated to a connected Coral Edge TPU "
- "device.");
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h
- index a4fee55abe158..2ca42fb7f3fbe 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h
- @@ -56,7 +56,8 @@ typedef NS_ENUM(NSUInteger, TFLSupportErrorCode) {
-
- /** TensorFlow Lite metadata error codes. */
-
- - /** Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer. */
- + /** Unexpected schema version (aka file_identifier) in the Metadata
- + FlatBuffer. */
- TFLSupportErrorCodeMetadataInvalidSchemaVersionError = 200,
-
- /** No such associated file within metadata, or file has not been packed. */
- @@ -198,11 +199,13 @@ typedef NS_ENUM(NSUInteger, TFLSupportErrorCode) {
- */
- TFLSupportErrorCodeImageProcessingBackendError,
-
- - /** kNotFound indicates some requested entity (such as a file or directory) was not found. */
- + /** kNotFound indicates some requested entity (such as a file or directory)
- + was not found. */
- TFLSupportErrorCodeNotFoundError = 900,
-
- - /** kInternal indicates an internal error has occurred and some invariants expected by the
- - * underlying system have not been satisfied. This error code is reserved for serious errors.
- + /** kInternal indicates an internal error has occurred and some invariants
- + * expected by the underlying system have not been satisfied. This error code
- + * is reserved for serious errors.
- */
- TFLSupportErrorCodeInternalError,
- } NS_SWIFT_NAME(SupportErrorCode);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h
- index f3d71984a3213..58710c6f8eeeb 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h
- @@ -25,36 +25,36 @@ NS_ASSUME_NONNULL_BEGIN
- *
- * @param code Error code.
- * @param description Error description.
- - * @param error Pointer to the memory location where the created error should be saved. If `nil`,
- - * no error will be saved.
- + * @param error Pointer to the memory location where the created error should be
- + * saved. If `nil`, no error will be saved.
- */
- -+ (void)createCustomError:(NSError **)error
- ++ (void)createCustomError:(NSError**)error
- withCode:(NSInteger)code
- - description:(NSString *)description;
- + description:(NSString*)description;
-
- /**
- * Converts a C library error, TfLiteSupportError to an NSError.
- *
- * @param supportError C library error.
- - * @param error Pointer to the memory location where the created error should be saved. If `nil`,
- - * no error will be saved.
- + * @param error Pointer to the memory location where the created error should be
- + * saved. If `nil`, no error will be saved.
- */
- -+ (BOOL)checkCError:(TfLiteSupportError *)supportError toError:(NSError **)error;
- ++ (BOOL)checkCError:(TfLiteSupportError*)supportError toError:(NSError**)error;
-
- /**
- - * Allocates a block of memory with the specified size and returns a pointer to it. If memory
- - * cannot be allocated because of an invalid memSize, it saves an error. In other cases, it
- - * terminates program execution.
- + * Allocates a block of memory with the specified size and returns a pointer to
- + * it. If memory cannot be allocated because of an invalid memSize, it saves an
- + * error. In other cases, it terminates program execution.
- *
- * @param memSize size of memory to be allocated
- - * @param error Pointer to the memory location where errors if any should be saved. If `nil`, no
- - * error will be saved.
- + * @param error Pointer to the memory location where errors if any should be
- + * saved. If `nil`, no error will be saved.
- *
- - * @return Pointer to the allocated block of memory on successfull allocation. nil in case as
- - * error is encountered because of invalid memSize. If failure is due to any other reason, method
- - * terminates program execution.
- + * @return Pointer to the allocated block of memory on successfull allocation.
- + * nil in case as error is encountered because of invalid memSize. If failure is
- + * due to any other reason, method terminates program execution.
- */
- -+ (void *)mallocWithSize:(size_t)memSize error:(NSError **)error;
- ++ (void*)mallocWithSize:(size_t)memSize error:(NSError**)error;
- @end
-
- NS_ASSUME_NONNULL_END
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m
- index 3904b0ba11d68..9e23b5b571386 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m
- @@ -20,23 +20,26 @@ static NSString *const TFLSupportTaskErrorDomain = @"org.tensorflow.lite.tasks";
-
- @implementation TFLCommonUtils
-
- -+ (void)createCustomError:(NSError **)error
- ++ (void)createCustomError:(NSError**)error
- withCode:(NSInteger)code
- - description:(NSString *)description {
- + description:(NSString*)description {
- if (error) {
- - *error = [NSError errorWithDomain:TFLSupportTaskErrorDomain
- - code:code
- - userInfo:@{NSLocalizedDescriptionKey : description}];
- + *error =
- + [NSError errorWithDomain:TFLSupportTaskErrorDomain
- + code:code
- + userInfo:@{NSLocalizedDescriptionKey : description}];
- }
- }
-
- -+ (BOOL)checkCError:(TfLiteSupportError *)supportError toError:(NSError **)error {
- ++ (BOOL)checkCError:(TfLiteSupportError*)supportError toError:(NSError**)error {
- if (!supportError) {
- return YES;
- }
- - NSString *description = [NSString stringWithCString:supportError->message
- + NSString* description = [NSString stringWithCString:supportError->message
- encoding:NSUTF8StringEncoding];
- - [self createCustomError:error withCode:supportError->code description:description];
- + [self createCustomError:error
- + withCode:supportError->code
- + description:description];
- return NO;
- }
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h
- index 79b6ba238e982..a5db97038a047 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h
- @@ -23,26 +23,28 @@ NS_ASSUME_NONNULL_BEGIN
- @property(nonatomic, readonly) NSUInteger size;
-
- /** Pointer to float array wrapped by `TFLFloatBuffer`. */
- -@property(nonatomic, readonly) float *data;
- +@property(nonatomic, readonly) float* data;
-
- /**
- - * Initializes a new `TFLFloatBuffer` by copying the elements of the given float data array.
- + * Initializes a new `TFLFloatBuffer` by copying the elements of the given float
- + * data array.
- *
- - * @param data A pointer to a float data array whose values are to be copied into the buffer.
- + * @param data A pointer to a float data array whose values are to be copied
- + * into the buffer.
- * @param size Size of the array float data array.
- *
- - * @return A new instance of `TFLFloatBuffer` initialized with the elements of the given float data
- - * array.
- + * @return A new instance of `TFLFloatBuffer` initialized with the elements of
- + * the given float data array.
- */
- -- (instancetype)initWithData:(float *)data size:(NSUInteger)size;
- +- (instancetype)initWithData:(float*)data size:(NSUInteger)size;
-
- /**
- * Initializes a `TFLFloatBuffer` of the specified size with zeros.
- *
- * @param size Number of elements the `TFLFloatBuffer` can hold.
- *
- - * @return A new instance of `TFLFloatBuffer` of the given size with all elements initialized to
- - * zero.
- + * @return A new instance of `TFLFloatBuffer` of the given size with all
- + * elements initialized to zero.
- */
- - (instancetype)initWithSize:(NSUInteger)size;
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.m
- index 24d50affb27aa..d32fc4363efc2 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.m
- @@ -16,7 +16,7 @@
-
- @implementation TFLFloatBuffer
-
- -- (instancetype)initWithData:(float *)data size:(NSUInteger)size {
- +- (instancetype)initWithData:(float*)data size:(NSUInteger)size {
- self = [self init];
- if (self) {
- _size = size;
- @@ -43,7 +43,7 @@
- return self;
- }
-
- -- (id)copyWithZone:(NSZone *)zone {
- +- (id)copyWithZone:(NSZone*)zone {
- return [[TFLFloatBuffer alloc] initWithData:_data size:_size];
- }
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h
- index 5a0ab68974b88..b300de6b94d89 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h
- @@ -17,13 +17,14 @@
-
- NS_ASSUME_NONNULL_BEGIN
-
- -/** An wrapper class which stores a buffer that is written in circular fashion. */
- +/** An wrapper class which stores a buffer that is written in circular fashion.
- + */
- @interface TFLRingBuffer : NSObject
-
- /**
- * A copy of all the internal ring buffer elements in order.
- */
- -@property(nullable, nonatomic, readonly) TFLFloatBuffer *floatBuffer;
- +@property(nullable, nonatomic, readonly) TFLFloatBuffer* floatBuffer;
-
- /**
- * Capacity of the ring buffer in number of elements.
- @@ -36,34 +37,37 @@ NS_ASSUME_NONNULL_BEGIN
- *
- * @param size Size of the ring buffer.
- *
- - * @return A new instance of `TFLRingBuffer` with the given size and all elements
- - * initialized to zero.
- + * @return A new instance of `TFLRingBuffer` with the given size and all
- + * elements initialized to zero.
- */
- - (instancetype)initWithBufferSize:(NSUInteger)size;
-
- /**
- - * Loads a slice of a float array to the ring buffer. If the float array is longer than ring
- - * buffer's capacity, samples with lower indices in the array will be ignored.
- + * Loads a slice of a float array to the ring buffer. If the float array is
- + * longer than ring buffer's capacity, samples with lower indices in the array
- + * will be ignored.
- *
- * @return Boolean indicating success or failure of loading operation.
- */
- -- (BOOL)loadBuffer:(TFLFloatBuffer *)sourceBuffer
- +- (BOOL)loadBuffer:(TFLFloatBuffer*)sourceBuffer
- offset:(NSUInteger)offset
- size:(NSUInteger)size
- - error:(NSError **)error;
- + error:(NSError**)error;
-
- /**
- - * Returns a `TFLFloatBuffer` with a copy of size number of the ring buffer elements in order
- - * starting at offset, i.e, buffer[offset:offset+size].
- + * Returns a `TFLFloatBuffer` with a copy of size number of the ring buffer
- + * elements in order starting at offset, i.e, buffer[offset:offset+size].
- *
- - * @param offset Offset in the ring buffer from which elements are to be returned.
- + * @param offset Offset in the ring buffer from which elements are to be
- + * returned.
- *
- * @param size Number of elements to be returned.
- *
- - * @return A new `TFLFloatBuffer` if offset + size is within the bounds of the ring buffer,
- - * otherwise nil.
- + * @return A new `TFLFloatBuffer` if offset + size is within the bounds of the
- + * ring buffer, otherwise nil.
- */
- -- (nullable TFLFloatBuffer *)floatBufferWithOffset:(NSUInteger)offset size:(NSUInteger)size;
- +- (nullable TFLFloatBuffer*)floatBufferWithOffset:(NSUInteger)offset
- + size:(NSUInteger)size;
-
- /**
- * Clears the `TFLRingBuffer` by setting all the elements to zero .
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.m
- index 675f7058fff61..57495409f51c8 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.m
- @@ -18,7 +18,7 @@
-
- @implementation TFLRingBuffer {
- NSUInteger _nextIndex;
- - TFLFloatBuffer *_buffer;
- + TFLFloatBuffer* _buffer;
- }
-
- - (instancetype)initWithBufferSize:(NSUInteger)size {
- @@ -29,18 +29,18 @@
- return self;
- }
-
- -- (BOOL)loadBuffer:(TFLFloatBuffer *)sourceBuffer
- +- (BOOL)loadBuffer:(TFLFloatBuffer*)sourceBuffer
- offset:(NSUInteger)offset
- size:(NSUInteger)size
- - error:(NSError **)error {
- + error:(NSError**)error {
- NSUInteger sizeToCopy = size;
- NSUInteger newOffset = offset;
-
- if (offset + size > sourceBuffer.size) {
- - [TFLCommonUtils
- - createCustomError:error
- - withCode:TFLSupportErrorCodeInvalidArgumentError
- - description:@"offset + size exceeds the maximum size of the source buffer."];
- + [TFLCommonUtils createCustomError:error
- + withCode:TFLSupportErrorCodeInvalidArgumentError
- + description:@"offset + size exceeds the maximum size "
- + @"of the source buffer."];
- return NO;
- }
-
- @@ -51,13 +51,15 @@
- newOffset = offset + (size - _buffer.size);
- }
-
- - // If the new nextIndex + sizeToCopy is smaller than the size of the ring buffer directly
- - // copy all elements to the end of the ring buffer.
- + // If the new nextIndex + sizeToCopy is smaller than the size of the ring
- + // buffer directly copy all elements to the end of the ring buffer.
- if (_nextIndex + sizeToCopy < _buffer.size) {
- - memcpy(_buffer.data + _nextIndex, sourceBuffer.data + newOffset, sizeof(float) * sizeToCopy);
- + memcpy(_buffer.data + _nextIndex, sourceBuffer.data + newOffset,
- + sizeof(float) * sizeToCopy);
- } else {
- NSUInteger endChunkSize = _buffer.size - _nextIndex;
- - memcpy(_buffer.data + _nextIndex, sourceBuffer.data + newOffset, sizeof(float) * endChunkSize);
- + memcpy(_buffer.data + _nextIndex, sourceBuffer.data + newOffset,
- + sizeof(float) * endChunkSize);
-
- NSUInteger startChunkSize = sizeToCopy - endChunkSize;
- memcpy(_buffer.data, sourceBuffer.data + newOffset + endChunkSize,
- @@ -69,16 +71,17 @@
- return YES;
- }
-
- -- (TFLFloatBuffer *)floatBuffer {
- +- (TFLFloatBuffer*)floatBuffer {
- return [self floatBufferWithOffset:0 size:self.size];
- }
-
- -- (nullable TFLFloatBuffer *)floatBufferWithOffset:(NSUInteger)offset size:(NSUInteger)size {
- +- (nullable TFLFloatBuffer*)floatBufferWithOffset:(NSUInteger)offset
- + size:(NSUInteger)size {
- if (offset + size > _buffer.size) {
- return nil;
- }
-
- - TFLFloatBuffer *bufferToReturn = [[TFLFloatBuffer alloc] initWithSize:size];
- + TFLFloatBuffer* bufferToReturn = [[TFLFloatBuffer alloc] initWithSize:size];
-
- // Return buffer in correct order.
- // Compute offset in flat ring buffer array considering warping.
- @@ -86,17 +89,21 @@
-
- // If no; elements to be copied are within the end of the flat ring buffer.
- if ((correctOffset + size) <= _buffer.size) {
- - memcpy(bufferToReturn.data, _buffer.data + correctOffset, sizeof(float) * size);
- + memcpy(bufferToReturn.data, _buffer.data + correctOffset,
- + sizeof(float) * size);
- } else {
- - // If no; elements to be copied warps around to the beginning of the ring buffer.
- - // Copy the chunk starting at ringBuffer[nextIndex + offset : size] to
- - // beginning of the result array.
- + // If no; elements to be copied warps around to the beginning of the ring
- + // buffer. Copy the chunk starting at ringBuffer[nextIndex + offset : size]
- + // to beginning of the result array.
- NSInteger endChunkSize = _buffer.size - correctOffset;
- - memcpy(bufferToReturn.data, _buffer.data + correctOffset, sizeof(float) * endChunkSize);
- + memcpy(bufferToReturn.data, _buffer.data + correctOffset,
- + sizeof(float) * endChunkSize);
-
- - // Next copy the chunk starting at ringBuffer[0 : size - endChunkSize] to the result array.
- + // Next copy the chunk starting at ringBuffer[0 : size - endChunkSize] to
- + // the result array.
- NSInteger firstChunkSize = size - endChunkSize;
- - memcpy(bufferToReturn.data + endChunkSize, _buffer.data, sizeof(float) * firstChunkSize);
- + memcpy(bufferToReturn.data + endChunkSize, _buffer.data,
- + sizeof(float) * firstChunkSize);
- }
-
- return bufferToReturn;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h
- index a117bd7b3c4c3..5058f7c9a5a7b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h
- @@ -18,7 +18,7 @@
- NS_ASSUME_NONNULL_BEGIN
-
- @interface TFLBaseOptions (Helpers)
- -- (void)copyToCOptions:(TfLiteBaseOptions *)cBaseOptions;
- +- (void)copyToCOptions:(TfLiteBaseOptions*)cBaseOptions;
- @end
-
- NS_ASSUME_NONNULL_END
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h
- index 330132f4ba138..7ab7e7240791e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h
- @@ -19,10 +19,10 @@ NS_ASSUME_NONNULL_BEGIN
- NS_SWIFT_NAME(CpuSettings)
- @interface TFLCpuSettings : NSObject <NSCopying>
-
- -/** Specifies the number of threads to be used for TFLite ops that support multi-threadingwhen
- - * running inference with CPU.
- - * @discussion This property hould be greater than 0 or equal to -1. Setting it to -1 has the
- - * effect to let TFLite runtime set the value.
- +/** Specifies the number of threads to be used for TFLite ops that support
- + * multi-threadingwhen running inference with CPU.
- + * @discussion This property hould be greater than 0 or equal to -1. Setting it
- + * to -1 has the effect to let TFLite runtime set the value.
- */
- @property(nonatomic) int numThreads;
-
- @@ -35,7 +35,7 @@ NS_SWIFT_NAME(ComputeSettings)
- @interface TFLComputeSettings : NSObject <NSCopying>
-
- /** Holds cpu settings. */
- -@property(nonatomic, copy) TFLCpuSettings *cpuSettings;
- +@property(nonatomic, copy) TFLCpuSettings* cpuSettings;
-
- @end
-
- @@ -46,30 +46,32 @@ NS_SWIFT_NAME(ExternalFile)
- @interface TFLExternalFile : NSObject <NSCopying>
-
- /** Path to the file in bundle. */
- -@property(nonatomic, copy) NSString *filePath;
- +@property(nonatomic, copy) NSString* filePath;
- /// Add provision for other sources in future.
-
- @end
-
- /**
- - * Holds the base options that is used for creation of any type of task. It has fields with
- - * important information acceleration configuration, tflite model source etc.
- + * Holds the base options that is used for creation of any type of task. It has
- + * fields with important information acceleration configuration, tflite model
- + * source etc.
- */
- NS_SWIFT_NAME(BaseOptions)
- @interface TFLBaseOptions : NSObject <NSCopying>
-
- /**
- - * The external model file, as a single standalone TFLite file. It could be packed with TFLite Model
- - * Metadata[1] and associated files if exist. Fail to provide the necessary metadata and associated
- - * files might result in errors.
- + * The external model file, as a single standalone TFLite file. It could be
- + * packed with TFLite Model Metadata[1] and associated files if exist. Fail to
- + * provide the necessary metadata and associated files might result in errors.
- */
- -@property(nonatomic, copy) TFLExternalFile *modelFile;
- +@property(nonatomic, copy) TFLExternalFile* modelFile;
-
- /**
- - * Holds settings for one possible acceleration configuration including.cpu/gpu settings.
- - * Please see documentation of TfLiteComputeSettings and its members for more details.
- + * Holds settings for one possible acceleration configuration including.cpu/gpu
- + * settings. Please see documentation of TfLiteComputeSettings and its members
- + * for more details.
- */
- -@property(nonatomic, copy) TFLComputeSettings *computeSettings;
- +@property(nonatomic, copy) TFLComputeSettings* computeSettings;
-
- @end
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.h
- index 617fa3ae7120e..6f515e46744b9 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.h
- @@ -30,7 +30,7 @@ NS_ASSUME_NONNULL_BEGIN
- * results returned by inference methods of the iOS TF Lite Task Classification
- * tasks.
- */
- -+ (TFLCategory *)categoryWithCCategory:(TfLiteCategory *)cCategory;
- ++ (TFLCategory*)categoryWithCCategory:(TfLiteCategory*)cCategory;
- @end
-
- NS_ASSUME_NONNULL_END
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.m
- index 7d49c36aa48c9..4139525500a59 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.m
- @@ -19,8 +19,8 @@
- + (TFLCategory *)categoryWithCCategory:(TfLiteCategory *)cCategory {
- if (cCategory == nil) return nil;
-
- - NSString *displayName;
- - NSString *label;
- + NSString* displayName;
- + NSString* label;
-
- if (cCategory->display_name != nil) {
- displayName = [NSString stringWithCString:cCategory->display_name
- @@ -28,7 +28,8 @@
- }
-
- if (cCategory->label != nil) {
- - label = [NSString stringWithCString:cCategory->label encoding:NSUTF8StringEncoding];
- + label = [NSString stringWithCString:cCategory->label
- + encoding:NSUTF8StringEncoding];
- }
-
- return [[TFLCategory alloc] initWithIndex:(NSInteger)cCategory->index
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.h
- index 91060ef4f1840..5c521f2239ab7 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.h
- @@ -20,24 +20,25 @@ NS_ASSUME_NONNULL_BEGIN
- NS_SWIFT_NAME(ClassificationCategory)
- @interface TFLCategory : NSObject
-
- -/** Index of the class in the corresponding label map, usually packed in the TFLite Model
- - * Metadata. */
- +/** Index of the class in the corresponding label map, usually packed in the
- + * TFLite Model Metadata. */
- @property(nonatomic, readonly) NSInteger index;
-
- /** Confidence score for this class . */
- @property(nonatomic, readonly) float score;
-
- /** Class name of the class. */
- -@property(nonatomic, readonly, nullable) NSString *label;
- +@property(nonatomic, readonly, nullable) NSString* label;
-
- /** Display name of the class. */
- -@property(nonatomic, readonly, nullable) NSString *displayName;
- +@property(nonatomic, readonly, nullable) NSString* displayName;
-
- /**
- - * Initializes a new `TFLCategory` with the given index, score, label and display name.
- + * Initializes a new `TFLCategory` with the given index, score, label and
- + * display name.
- *
- - * @param index Index of the class in the corresponding label map, usually packed in the TFLite
- - * Model Metadata.
- + * @param index Index of the class in the corresponding label map, usually
- + * packed in the TFLite Model Metadata.
- *
- * @param score Confidence score for this class.
- *
- @@ -45,12 +46,13 @@ NS_SWIFT_NAME(ClassificationCategory)
- *
- * @param displayName Display name of the class.
- *
- - * @return An instance of `TFLCategory` initialized with the given index, score, label and display name.
- + * @return An instance of `TFLCategory` initialized with the given index, score,
- + * label and display name.
- */
- - (instancetype)initWithIndex:(NSInteger)index
- score:(float)score
- - label:(nullable NSString *)label
- - displayName:(nullable NSString *)displayName;
- + label:(nullable NSString*)label
- + displayName:(nullable NSString*)displayName;
-
- - (instancetype)init NS_UNAVAILABLE;
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.m
- index b72c3b55fdaf1..603c5a27c9673 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.m
- @@ -18,8 +18,8 @@
-
- - (instancetype)initWithIndex:(NSInteger)index
- score:(float)score
- - label:(nullable NSString *)label
- - displayName:(nullable NSString *)displayName {
- + label:(nullable NSString*)label
- + displayName:(nullable NSString*)displayName {
- self = [super init];
- if (self) {
- _index = index;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h
- index b12c118e89021..152aa33dbdb59 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h
- @@ -18,11 +18,11 @@
- NS_ASSUME_NONNULL_BEGIN
-
- @interface TFLClassificationOptions (Helpers)
- -- (BOOL)copyToCOptions:(TfLiteClassificationOptions *)cClassificationOptions
- - error:(NSError **)error;
- +- (BOOL)copyToCOptions:(TfLiteClassificationOptions*)cClassificationOptions
- + error:(NSError**)error;
-
- - (void)deleteAllocatedMemoryOfClassificationOptions:
- - (TfLiteClassificationOptions *)cClassificationOptions;
- + (TfLiteClassificationOptions*)cClassificationOptions;
- @end
-
- NS_ASSUME_NONNULL_END
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m
- index 84e8fa5e234fb..767e5e4d577a3 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m
- @@ -20,21 +20,28 @@
-
- + (char **)cStringArrayFromNSArray:(NSArray<NSString *> *)strings error:(NSError **)error {
- if (strings.count <= 0) {
- - [TFLCommonUtils createCustomError:error
- - withCode:TFLSupportErrorCodeInvalidArgumentError
- - description:@"Invalid length of strings found for list type options."];
- + [TFLCommonUtils
- + createCustomError:error
- + withCode:TFLSupportErrorCodeInvalidArgumentError
- + description:
- + @"Invalid length of strings found for list type options."];
- return nil;
- }
-
- - char **cStrings = [TFLCommonUtils mallocWithSize:strings.count * sizeof(char *) error:error];
- - if (!cStrings) return NULL;
- + char** cStrings = [TFLCommonUtils mallocWithSize:strings.count * sizeof(char*)
- + error:error];
- + if (!cStrings)
- + return NULL;
-
- for (NSInteger i = 0; i < strings.count; i++) {
- cStrings[i] = [TFLCommonUtils
- - mallocWithSize:([strings[i] lengthOfBytesUsingEncoding:NSUTF8StringEncoding] + 1) *
- + mallocWithSize:([strings[i]
- + lengthOfBytesUsingEncoding:NSUTF8StringEncoding] +
- + 1) *
- sizeof(char)
- error:error];
- - if (!cStrings[i]) return NULL;
- + if (!cStrings[i])
- + return NULL;
-
- strcpy(cStrings[i], strings[i].UTF8String);
- }
- @@ -77,14 +84,16 @@
-
- if (self.displayNamesLocale) {
- if (self.displayNamesLocale.UTF8String) {
- - cClassificationOptions->display_names_local = strdup(self.displayNamesLocale.UTF8String);
- + cClassificationOptions->display_names_local =
- + strdup(self.displayNamesLocale.UTF8String);
- if (!cClassificationOptions->display_names_local) {
- exit(-1); // Memory Allocation Failed.
- }
- } else {
- - [TFLCommonUtils createCustomError:error
- - withCode:TFLSupportErrorCodeInvalidArgumentError
- - description:@"Could not convert (NSString *) to (char *)."];
- + [TFLCommonUtils
- + createCustomError:error
- + withCode:TFLSupportErrorCodeInvalidArgumentError
- + description:@"Could not convert (NSString *) to (char *)."];
- return NO;
- }
- }
- @@ -93,7 +102,7 @@
- }
-
- - (void)deleteAllocatedMemoryOfClassificationOptions:
- - (TfLiteClassificationOptions *)cClassificationOptions {
- + (TfLiteClassificationOptions*)cClassificationOptions {
- if (self.labelAllowList) {
- [TFLClassificationOptions deleteCStringsArray:cClassificationOptions->label_allowlist.list
- count:cClassificationOptions->label_allowlist.length];
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h
- index 41b69bec8a7d8..ce3f5d6580913 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h
- @@ -23,13 +23,14 @@ NS_SWIFT_NAME(ClassificationOptions)
- @interface TFLClassificationOptions : NSObject <NSCopying>
-
- /** If set, all classes in this list will be filtered out from the results . */
- -@property(nonatomic, copy) NSArray *labelDenyList;
- +@property(nonatomic, copy) NSArray* labelDenyList;
-
- -/** If set, all classes not in this list will be filtered out from the results . */
- -@property(nonatomic, copy) NSArray *labelAllowList;
- +/** If set, all classes not in this list will be filtered out from the results .
- + */
- +@property(nonatomic, copy) NSArray* labelAllowList;
-
- /** Display names local for display names*/
- -@property(nonatomic, copy) NSString *displayNamesLocale;
- +@property(nonatomic, copy) NSString* displayNamesLocale;
-
- /** Results with score threshold greater than this value are returned . */
- @property(nonatomic) float scoreThreshold;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h
- index 7ef58fc5b76ce..351e87db729c6 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h
- @@ -20,17 +20,18 @@ NS_ASSUME_NONNULL_BEGIN
- @interface TFLClassificationResult (Helpers)
-
- /**
- - * Creates and returns a TFLClassificationResult from a TfLiteClassificationResult returned by
- - * TFLite Task C Library Classification tasks.
- + * Creates and returns a TFLClassificationResult from a
- + * TfLiteClassificationResult returned by TFLite Task C Library Classification
- + * tasks.
- *
- - * @param cClassificationResult Classification results returned by TFLite Task C Library
- - * Classification tasks
- + * @param cClassificationResult Classification results returned by TFLite Task C
- + * Library Classification tasks
- *
- - * @return Classification Result of type TFLClassificationResult to be returned by inference methods
- - * of the iOS TF Lite Task Classification tasks.
- + * @return Classification Result of type TFLClassificationResult to be returned
- + * by inference methods of the iOS TF Lite Task Classification tasks.
- */
- -+ (TFLClassificationResult *)classificationResultWithCResult:
- - (TfLiteClassificationResult *)cClassificationResult;
- ++ (TFLClassificationResult*)classificationResultWithCResult:
- + (TfLiteClassificationResult*)cClassificationResult;
- @end
-
- NS_ASSUME_NONNULL_END
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.m
- index c8744a3bf99c6..52e92852d88a9 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.m
- @@ -19,30 +19,34 @@
-
- + (TFLClassificationResult *)classificationResultWithCResult:
- (TfLiteClassificationResult *)cClassificationResult {
- - if (!cClassificationResult) return nil;
- + if (!cClassificationResult)
- + return nil;
-
- NSMutableArray *classificationHeads = [[NSMutableArray alloc] init];
- for (int i = 0; i < cClassificationResult->size; i++) {
- TfLiteClassifications cClassifications = cClassificationResult->classifications[i];
- - NSMutableArray *categories = [[NSMutableArray alloc] init];
- + NSMutableArray* categories = [[NSMutableArray alloc] init];
- for (int j = 0; j < cClassifications.size; j++) {
- TfLiteCategory cCategory = cClassifications.categories[j];
- [categories addObject:[TFLCategory categoryWithCCategory:&cCategory]];
- }
-
- - NSString *headName = nil;
- + NSString* headName = nil;
-
- if (cClassifications.head_name) {
- - headName = [NSString stringWithCString:cClassifications.head_name encoding:NSUTF8StringEncoding];
- + headName = [NSString stringWithCString:cClassifications.head_name
- + encoding:NSUTF8StringEncoding];
- }
- -
- - TFLClassifications *classifications = [[TFLClassifications alloc] initWithHeadIndex:cClassifications.head_index
- - headName:headName
- - categories:categories];
- +
- + TFLClassifications* classifications = [[TFLClassifications alloc]
- + initWithHeadIndex:cClassifications.head_index
- + headName:headName
- + categories:categories];
-
- [classificationHeads addObject:classifications];
- }
-
- - return [[TFLClassificationResult alloc] initWithClassifications:classificationHeads];
- + return [[TFLClassificationResult alloc]
- + initWithClassifications:classificationHeads];
- }
- @end
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h
- index 72d5c85dec0d6..052b4f1daf710 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h
- @@ -17,58 +17,66 @@ limitations under the License.
-
- NS_ASSUME_NONNULL_BEGIN
-
- -/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */
- +/** Encapsulates list of predicted classes (aka labels) for a given image
- + * classifier head. */
- NS_SWIFT_NAME(Classifications)
- @interface TFLClassifications : NSObject
-
- /**
- - * The index of the image classifier head these classes refer to. This is useful for multi-head
- - * models.
- + * The index of the image classifier head these classes refer to. This is useful
- + * for multi-head models.
- */
- @property(nonatomic, readonly) NSInteger headIndex;
-
- /** The name of the classifier head, which is the corresponding tensor metadata
- - * name. See https://github.com/tensorflow/tflite-support/blob/710e323265bfb71fdbdd72b3516e00cff15c0326/tensorflow_lite_support/metadata/metadata_schema.fbs#L545
- - * This will always be NULL for the `TFLClassifications` in the `TFLClassificationResult` returned by the follwing methods of `TFLImageClassifier`.
- + * name. See
- + * https://github.com/tensorflow/tflite-support/blob/710e323265bfb71fdbdd72b3516e00cff15c0326/tensorflow_lite_support/metadata/metadata_schema.fbs#L545
- + * This will always be NULL for the `TFLClassifications` in the
- + * `TFLClassificationResult` returned by the follwing methods of
- + * `TFLImageClassifier`.
- * 1. -[TFLImageClassifier classifyWithGMLImage:error:]
- * 2. -[TFLImageClassifier classifyWithGMLImage:regionOfInterest:error:]
- */
- -@property(nonatomic, readonly) NSString *headName;
- +@property(nonatomic, readonly) NSString* headName;
-
- -/** The array of predicted classes, usually sorted by descending scores (e.g.from high to low
- - * probability). */
- -@property(nonatomic, readonly) NSArray<TFLCategory *> *categories;
- +/** The array of predicted classes, usually sorted by descending scores
- + * (e.g.from high to low probability). */
- +@property(nonatomic, readonly) NSArray<TFLCategory*>* categories;
-
- /**
- - * Initializes a new `TFLClassifications` with the given head index and array of categories.
- - * head name is initialized to `nil`.
- + * Initializes a new `TFLClassifications` with the given head index and array of
- + * categories. head name is initialized to `nil`.
- *
- - * @param headIndex The index of the image classifier head these classes refer to.
- + * @param headIndex The index of the image classifier head these classes refer
- + * to.
- * @param categories An array of `TFLCategory` objects encapsulating a list of
- - * predictions usually sorted by descending scores (e.g. from high to low probability).
- + * predictions usually sorted by descending scores (e.g. from high to low
- + * probability).
- *
- - * @return An instance of `TFLClassifications` initialized with the given head index and
- - * array of categories.
- + * @return An instance of `TFLClassifications` initialized with the given head
- + * index and array of categories.
- */
- - (instancetype)initWithHeadIndex:(NSInteger)headIndex
- - categories:(NSArray<TFLCategory *> *)categories;
- -
- + categories:(NSArray<TFLCategory*>*)categories;
-
- /**
- - * Initializes a new `TFLClassifications` with the given head index, head name and array of categories.
- + * Initializes a new `TFLClassifications` with the given head index, head name
- + * and array of categories.
- *
- - * @param headIndex The index of the image classifier head these classes refer to.
- - * @param headName The name of the classifier head, which is the corresponding tensor metadata
- - * name.
- + * @param headIndex The index of the image classifier head these classes refer
- + * to.
- + * @param headName The name of the classifier head, which is the corresponding
- + * tensor metadata name.
- * @param categories An array of `TFLCategory` objects encapsulating a list of
- - * predictions usually sorted by descending scores (e.g. from high to low probability).
- + * predictions usually sorted by descending scores (e.g. from high to low
- + * probability).
- *
- - * @return An object of `TFLClassifications` initialized with the given head index, head name and
- - * array of categories.
- + * @return An object of `TFLClassifications` initialized with the given head
- + * index, head name and array of categories.
- */
- - (instancetype)initWithHeadIndex:(NSInteger)headIndex
- - headName:(nullable NSString *)headName
- - categories:(NSArray<TFLCategory *> *)categories;
- + headName:(nullable NSString*)headName
- + categories:(NSArray<TFLCategory*>*)categories;
-
- @end
-
- @@ -76,20 +84,23 @@ NS_SWIFT_NAME(Classifications)
- NS_SWIFT_NAME(ClassificationResult)
- @interface TFLClassificationResult : NSObject
-
- -/** Array of TFLClassifications objects containing image classifier predictions per image classifier
- - * head.
- +/** Array of TFLClassifications objects containing image classifier predictions
- + * per image classifier head.
- */
- -@property(nonatomic, readonly) NSArray<TFLClassifications *> *classifications;
- +@property(nonatomic, readonly) NSArray<TFLClassifications*>* classifications;
-
- /**
- - * Initializes a new `TFLClassificationResult` with the given array of classifications.
- + * Initializes a new `TFLClassificationResult` with the given array of
- + * classifications.
- *
- - * @param classifications An Aaray of `TFLClassifications` objects containing image classifier
- - * predictions per image classifier head.
- + * @param classifications An Aaray of `TFLClassifications` objects containing
- + * image classifier predictions per image classifier head.
- *
- - * @return An instance of 1TFLClassificationResult1 initialized with the given array of classifications.
- + * @return An instance of 1TFLClassificationResult1 initialized with the given
- + * array of classifications.
- */
- -- (instancetype)initWithClassifications:(NSArray<TFLClassifications *> *)classifications;
- +- (instancetype)initWithClassifications:
- + (NSArray<TFLClassifications*>*)classifications;
-
- @end
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m
- index f56600cb94f3b..0ea238417c891 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m
- @@ -17,9 +17,8 @@ limitations under the License.
- @implementation TFLClassifications
-
- - (instancetype)initWithHeadIndex:(NSInteger)headIndex
- - headName:(nullable NSString *)headName
- - categories:(NSArray<TFLCategory *> *)categories {
- -
- + headName:(nullable NSString*)headName
- + categories:(NSArray<TFLCategory*>*)categories {
- self = [super init];
- if (self) {
- _headIndex = headIndex;
- @@ -30,17 +29,18 @@ limitations under the License.
- }
-
- - (instancetype)initWithHeadIndex:(NSInteger)headIndex
- - categories:(NSArray<TFLCategory *> *)categories {
- + categories:(NSArray<TFLCategory*>*)categories {
- return [self initWithHeadIndex:headIndex headName:nil categories:categories];
- }
-
- @end
-
- @implementation TFLClassificationResult {
- - NSArray<TFLClassifications *> *_classifications;
- + NSArray<TFLClassifications*>* _classifications;
- }
-
- -- (instancetype)initWithClassifications:(NSArray<TFLClassifications *> *)classifications {
- +- (instancetype)initWithClassifications:
- + (NSArray<TFLClassifications*>*)classifications {
- self = [super init];
- if (self) {
- _classifications = classifications;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.h
- index 7f6e8cae27f2c..81efbcc1d8c57 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.h
- @@ -19,16 +19,17 @@ NS_ASSUME_NONNULL_BEGIN
-
- @interface TFLDetectionResult (Helpers)
- /**
- - * Creates and retrurns a TFLDetectionResult from a TfLiteDetectionResult returned by
- - * TFLite Task C Library Object Detection task.
- + * Creates and retrurns a TFLDetectionResult from a TfLiteDetectionResult
- + * returned by TFLite Task C Library Object Detection task.
- *
- * @param cDetectionResult Detection results returned by TFLite Task C Library
- * Object Detection task.
- *
- - * @return Detection Result of type TFLDetectionResult to be returned by inference methods
- - * of the iOS TF Lite Task Object Detection task.
- + * @return Detection Result of type TFLDetectionResult to be returned by
- + * inference methods of the iOS TF Lite Task Object Detection task.
- */
- -+ (TFLDetectionResult *)detectionResultWithCResult:(TfLiteDetectionResult *)cDetectionResult;
- ++ (TFLDetectionResult*)detectionResultWithCResult:
- + (TfLiteDetectionResult*)cDetectionResult;
- @end
-
- NS_ASSUME_NONNULL_END
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.m
- index 405bddf117cdd..3ae292cb0ef3b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.m
- @@ -17,8 +17,10 @@
-
- @implementation TFLDetectionResult (Helpers)
-
- -+ (TFLDetectionResult *)detectionResultWithCResult:(TfLiteDetectionResult *)cDetectionResult {
- - if (!cDetectionResult) return nil;
- ++ (TFLDetectionResult*)detectionResultWithCResult:
- + (TfLiteDetectionResult*)cDetectionResult {
- + if (!cDetectionResult)
- + return nil;
-
- NSMutableArray *detections = [[NSMutableArray alloc] init];
- for (int i = 0; i < cDetectionResult->size; i++) {
- @@ -30,10 +32,11 @@
- TFLCategory *resultCategory = [TFLCategory categoryWithCCategory:&cCategory];
- [categories addObject:resultCategory];
- }
- - TFLDetection *detection = [[TFLDetection alloc]
- - initWithBoundingBox:CGRectMake(
- - cDetection.bounding_box.origin_x, cDetection.bounding_box.origin_y,
- - cDetection.bounding_box.width, cDetection.bounding_box.height)
- + TFLDetection* detection = [[TFLDetection alloc]
- + initWithBoundingBox:CGRectMake(cDetection.bounding_box.origin_x,
- + cDetection.bounding_box.origin_y,
- + cDetection.bounding_box.width,
- + cDetection.bounding_box.height)
- categories:categories];
- [detections addObject:detection];
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.h
- index 0c64aa98b6089..00cc75bbc161e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.h
- @@ -19,31 +19,35 @@ limitations under the License.
-
- NS_ASSUME_NONNULL_BEGIN
-
- -/** Encapsulates list of predicted classes (aka labels) and bounding box for a detected object. */
- +/** Encapsulates list of predicted classes (aka labels) and bounding box for a
- + * detected object. */
- NS_SWIFT_NAME(Detection)
- @interface TFLDetection : NSObject
-
- /**
- - * The index of the image classifier head these classes refer to. This is useful for multi-head
- - * models.
- + * The index of the image classifier head these classes refer to. This is useful
- + * for multi-head models.
- */
- @property(nonatomic, readonly) CGRect boundingBox;
-
- -/** The array of predicted classes, usually sorted by descending scores (e.g.from high to low
- - * probability). */
- -@property(nonatomic, readonly) NSArray<TFLCategory *> *categories;
- +/** The array of predicted classes, usually sorted by descending scores
- + * (e.g.from high to low probability). */
- +@property(nonatomic, readonly) NSArray<TFLCategory*>* categories;
-
- /**
- - * Initializes an object of `TFLDetection` with the given bounding box and array of categories.
- + * Initializes an object of `TFLDetection` with the given bounding box and array
- + * of categories.
- *
- - * @param boundingBox CGRect specifying the bounds of the object represented by this detection.
- - * @param categories Array of predicted classes, usually sorted by descending scores (e.g.from high
- - * to low probability).
- + * @param boundingBox CGRect specifying the bounds of the object represented by
- + * this detection.
- + * @param categories Array of predicted classes, usually sorted by descending
- + * scores (e.g.from high to low probability).
- *
- - * @return An instance of `TFLDetection` initialized with the given bounding box and array of categories.
- + * @return An instance of `TFLDetection` initialized with the given bounding box
- + * and array of categories.
- */
- - (instancetype)initWithBoundingBox:(CGRect)boundingBox
- - categories:(NSArray<TFLCategory *> *)categories;
- + categories:(NSArray<TFLCategory*>*)categories;
-
- - (instancetype)init NS_UNAVAILABLE;
-
- @@ -55,16 +59,17 @@ NS_SWIFT_NAME(Detection)
- NS_SWIFT_NAME(DetectionResult)
- @interface TFLDetectionResult : NSObject
-
- -@property(nonatomic, readonly) NSArray<TFLDetection *> *detections;
- +@property(nonatomic, readonly) NSArray<TFLDetection*>* detections;
-
- /**
- * Initializes a new `TFLDetectionResult` with the given array of detections.
- *
- * @param detections Array of detected objects of type TFLDetection.
- *
- - * @return An instance of `TFLDetectionResult` initialized with the given array of detections.
- + * @return An instance of `TFLDetectionResult` initialized with the given array
- + * of detections.
- */
- -- (instancetype)initWithDetections:(NSArray<TFLDetection *> *)detections;
- +- (instancetype)initWithDetections:(NSArray<TFLDetection*>*)detections;
-
- - (instancetype)init NS_UNAVAILABLE;
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.m
- index 280767e6a353a..14cec3bca3d08 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.m
- @@ -17,7 +17,7 @@ limitations under the License.
- @implementation TFLDetection
-
- - (instancetype)initWithBoundingBox:(CGRect)boundingBox
- - categories:(NSArray<TFLCategory *> *)categories {
- + categories:(NSArray<TFLCategory*>*)categories {
- self = [super init];
- if (self) {
- _boundingBox = boundingBox;
- @@ -30,7 +30,7 @@ limitations under the License.
-
- @implementation TFLDetectionResult
-
- -- (instancetype)initWithDetections:(NSArray<TFLDetection *> *)detections {
- +- (instancetype)initWithDetections:(NSArray<TFLDetection*>*)detections {
- self = [super init];
- if (self) {
- _detections = detections;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.h
- index c979fda53c70b..0a85efe2877bb 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.h
- @@ -28,8 +28,8 @@ NS_ASSUME_NONNULL_BEGIN
- * @return Segmentation Result of type TFLSegmentationResult to be returned by
- * inference methods of the iOS TF Lite Task Image Segmentation task.
- */
- -+ (TFLSegmentationResult *)segmentationResultWithCResult:
- - (TfLiteSegmentationResult *)cSegmentationResult;
- ++ (TFLSegmentationResult*)segmentationResultWithCResult:
- + (TfLiteSegmentationResult*)cSegmentationResult;
- @end
-
- NS_ASSUME_NONNULL_END
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.m
- index f2ea957ca3010..2a897f0ba3614 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.m
- @@ -16,29 +16,31 @@
-
- @implementation TFLSegmentationResult (Helpers)
-
- -+ (TFLSegmentationResult *)segmentationResultWithCResult:
- - (TfLiteSegmentationResult *)cSegmentationResult {
- - if (!cSegmentationResult) return nil;
- ++ (TFLSegmentationResult*)segmentationResultWithCResult:
- + (TfLiteSegmentationResult*)cSegmentationResult {
- + if (!cSegmentationResult)
- + return nil;
-
- - NSMutableArray *segmentations = [[NSMutableArray alloc] init];
- + NSMutableArray* segmentations = [[NSMutableArray alloc] init];
- for (int i = 0; i < cSegmentationResult->size; i++) {
- TfLiteSegmentation cSegmentation = cSegmentationResult->segmentations[i];
- - NSMutableArray *coloredLabels = [[NSMutableArray alloc] init];
- + NSMutableArray* coloredLabels = [[NSMutableArray alloc] init];
- for (int j = 0; j < cSegmentation.colored_labels_size; j++) {
- TfLiteColoredLabel cColoredLabel = cSegmentation.colored_labels[j];
-
- - NSString *displayName;
- + NSString* displayName;
- if (cColoredLabel.display_name) {
- displayName = [NSString stringWithCString:cColoredLabel.display_name
- encoding:NSUTF8StringEncoding];
- }
-
- - NSString *label;
- + NSString* label;
- if (cColoredLabel.label) {
- - label = [NSString stringWithCString:cColoredLabel.label encoding:NSUTF8StringEncoding];
- + label = [NSString stringWithCString:cColoredLabel.label
- + encoding:NSUTF8StringEncoding];
- }
-
- - TFLColoredLabel *coloredLabel =
- + TFLColoredLabel* coloredLabel =
- [[TFLColoredLabel alloc] initWithRed:(NSUInteger)cColoredLabel.r
- green:(NSUInteger)cColoredLabel.g
- blue:(NSUInteger)cColoredLabel.b
- @@ -47,27 +49,29 @@
- [coloredLabels addObject:coloredLabel];
- }
-
- - TFLSegmentation *segmentation;
- + TFLSegmentation* segmentation;
-
- if (cSegmentation.confidence_masks) {
- - NSMutableArray *confidenceMasks = [[NSMutableArray alloc] init];
- + NSMutableArray* confidenceMasks = [[NSMutableArray alloc] init];
- for (int i = 0; i < cSegmentation.colored_labels_size; i++) {
- - TFLConfidenceMask *confidenceMask =
- - [[TFLConfidenceMask alloc] initWithWidth:(NSInteger)cSegmentation.width
- - height:(NSInteger)cSegmentation.height
- - mask:cSegmentation.confidence_masks[i]];
- + TFLConfidenceMask* confidenceMask = [[TFLConfidenceMask alloc]
- + initWithWidth:(NSInteger)cSegmentation.width
- + height:(NSInteger)cSegmentation.height
- + mask:cSegmentation.confidence_masks[i]];
- [confidenceMasks addObject:confidenceMask];
- }
- - segmentation = [[TFLSegmentation alloc] initWithConfidenceMasks:confidenceMasks
- - coloredLabels:coloredLabels];
- + segmentation =
- + [[TFLSegmentation alloc] initWithConfidenceMasks:confidenceMasks
- + coloredLabels:coloredLabels];
-
- } else if (cSegmentation.category_mask) {
- - TFLCategoryMask *categoryMask =
- + TFLCategoryMask* categoryMask =
- [[TFLCategoryMask alloc] initWithWidth:(NSInteger)cSegmentation.width
- height:(NSInteger)cSegmentation.height
- mask:cSegmentation.category_mask];
- - segmentation = [[TFLSegmentation alloc] initWithCategoryMask:categoryMask
- - coloredLabels:coloredLabels];
- + segmentation =
- + [[TFLSegmentation alloc] initWithCategoryMask:categoryMask
- + coloredLabels:coloredLabels];
- }
-
- [segmentations addObject:segmentation];
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h
- index 1307e26294dd4..3aca4567ebe2e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h
- @@ -23,7 +23,7 @@ NS_SWIFT_NAME(ConfidenceMask)
- /**
- * Confidence masks of size `width` x `height` for any one class.
- */
- -@property(nonatomic, readonly) float *mask;
- +@property(nonatomic, readonly) float* mask;
-
- /**
- * The width of the mask. This is an intrinsic parameter of the model being
- @@ -42,7 +42,7 @@ NS_SWIFT_NAME(ConfidenceMask)
- */
- - (instancetype)initWithWidth:(NSInteger)width
- height:(NSInteger)height
- - mask:(float * _Nullable)mask;
- + mask:(float* _Nullable)mask;
-
- - (instancetype)init NS_UNAVAILABLE;
-
- @@ -59,7 +59,7 @@ NS_SWIFT_NAME(CategoryMask)
- * The value of each pixel in this mask represents the class to which the
- * pixel belongs.
- */
- -@property(nonatomic, readonly) UInt8 *mask;
- +@property(nonatomic, readonly) UInt8* mask;
-
- /**
- * The width of the mask. This is an intrinsic parameter of the model being
- @@ -80,15 +80,15 @@ NS_SWIFT_NAME(CategoryMask)
- *
- * @param width Width of the mask.
- * @param height Height of the mask.
- - * @param mask Flattened 2D-array of size `width` x `height`, in row major order.
- - * The value of each pixel in this mask represents the class to which the
- + * @param mask Flattened 2D-array of size `width` x `height`, in row major
- + * order. The value of each pixel in this mask represents the class to which the
- * pixel belongs.
- *
- * @return An instance of TFLCategoryMask initialized to the specified values.
- */
- - (instancetype)initWithWidth:(NSInteger)width
- height:(NSInteger)height
- - mask:(UInt8 * _Nullable)mask;
- + mask:(UInt8* _Nullable)mask;
-
- - (instancetype)init NS_UNAVAILABLE;
-
- @@ -107,17 +107,18 @@ NS_SWIFT_NAME(ColoredLabel)
- * The class name, as provided in the label map packed in the TFLite Model
- * Metadata.
- */
- -@property(nonatomic, readonly) NSString *label;
- +@property(nonatomic, readonly) NSString* label;
-
- /**
- * The display name, as provided in the label map (if available) packed in
- * the TFLite Model Metadata. See displayNamesLocale in
- * TFLClassificationOptions.
- */
- -@property(nonatomic, readonly) NSString *displayName;
- +@property(nonatomic, readonly) NSString* displayName;
-
- /**
- - * Initializes a new `TFLColoredLabel` with red, gree, blue color components, label and display name.
- + * Initializes a new `TFLColoredLabel` with red, gree, blue color components,
- + * label and display name.
- *
- * @param r Red component of the RGB color components.
- * @param g Green component of the RGB color components.
- @@ -125,13 +126,14 @@ NS_SWIFT_NAME(ColoredLabel)
- * @param label Class name.
- * @param displayName Display name.
- *
- - * @return An instance of TFLColoredLabel initialized with red, gree, blue color components, label and display name.
- + * @return An instance of TFLColoredLabel initialized with red, gree, blue color
- + * components, label and display name.
- */
- - (instancetype)initWithRed:(NSUInteger)r
- green:(NSUInteger)g
- blue:(NSUInteger)b
- - label:(NSString *)label
- - displayName:(NSString *)displayName;
- + label:(NSString*)label
- + displayName:(NSString*)displayName;
-
- - (instancetype)init NS_UNAVAILABLE;
-
- @@ -150,7 +152,8 @@ NS_SWIFT_NAME(Segmentation)
- * this particular class.
- * This property is mutually exclusive with `categoryMask`.
- */
- -@property(nonatomic, nullable, readonly) NSArray<TFLConfidenceMask *> *confidenceMasks;
- +@property(nonatomic, nullable, readonly)
- + NSArray<TFLConfidenceMask*>* confidenceMasks;
-
- /**
- * Holds the category mask.
- @@ -158,7 +161,7 @@ NS_SWIFT_NAME(Segmentation)
- * pixel belongs.
- * This property is mutually exclusive with `confidenceMasks`.
- */
- -@property(nonatomic, nullable, readonly) TFLCategoryMask *categoryMask;
- +@property(nonatomic, nullable, readonly) TFLCategoryMask* categoryMask;
-
- /**
- * The list of colored labels for all the supported categories (classes).
- @@ -167,33 +170,38 @@ NS_SWIFT_NAME(Segmentation)
- * `colored_labels[i]`, `confidence_masks` indices, i.e. `confidence_masks[i]`
- * is associated with `colored_labels[i]`.
- */
- -@property(nonatomic, readonly) NSArray<TFLColoredLabel *> *coloredLabels;
- +@property(nonatomic, readonly) NSArray<TFLColoredLabel*>* coloredLabels;
-
- + (instancetype)new NS_UNAVAILABLE;
-
- /**
- - * Initializes a new `TFLSegmentation` with an array of confidence masks and an array of colored labels.
- - * `categoryMask` is initialized to `nil` as it is mutually exclusive with `confidenceMasks`.
- + * Initializes a new `TFLSegmentation` with an array of confidence masks and an
- + * array of colored labels. `categoryMask` is initialized to `nil` as it is
- + * mutually exclusive with `confidenceMasks`.
- *
- * @param confidenceMasks An array of `TFLConfidenceMask` objects.
- * @param coloredLabels An array of `TFLColoredLabel` objects.
- *
- - * @return An instance of `TFLSegmentation` initialized with an array of confidence masks and an array of colored labels.
- + * @return An instance of `TFLSegmentation` initialized with an array of
- + * confidence masks and an array of colored labels.
- */
- -- (instancetype)initWithConfidenceMasks:(NSArray<TFLConfidenceMask *> *)confidenceMasks
- - coloredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels;
- +- (instancetype)
- + initWithConfidenceMasks:(NSArray<TFLConfidenceMask*>*)confidenceMasks
- + coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels;
-
- /**
- - * Initializes a new `TFLSegmentation` with a category mask and array of colored labels.
- - * `confidenceMasks` is initialized to `nil` as it is mutually exclusive with `categoryMask`.
- + * Initializes a new `TFLSegmentation` with a category mask and array of colored
- + * labels. `confidenceMasks` is initialized to `nil` as it is mutually exclusive
- + * with `categoryMask`.
- *
- * @param categoryMask A `TFLCategoryMask` object.
- * @param coloredLabels An array of `TFLColoredLabel` objects.
- *
- - * @return An instance of `TFLSegmentation` initialized with a category mask and array of colored labels.
- + * @return An instance of `TFLSegmentation` initialized with a category mask and
- + * array of colored labels.
- */
- -- (instancetype)initWithCategoryMask:(TFLCategoryMask *)categoryMask
- - coloredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels;
- +- (instancetype)initWithCategoryMask:(TFLCategoryMask*)categoryMask
- + coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels;
-
- - (instancetype)init NS_UNAVAILABLE;
-
- @@ -209,7 +217,7 @@ NS_SWIFT_NAME(SegmentationResult)
- * e.g. instance segmentation models, which may return one segmentation per
- * object.
- */
- -@property(nonatomic, readonly) NSArray<TFLSegmentation *> *segmentations;
- +@property(nonatomic, readonly) NSArray<TFLSegmentation*>* segmentations;
-
- + (instancetype)new NS_UNAVAILABLE;
-
- @@ -218,9 +226,10 @@ NS_SWIFT_NAME(SegmentationResult)
- *
- * @param segmentations An array of `TFLSegmentation` objects.
- *
- - * @return An instance of `TFLSegmentationResult` initialized with an array of segmentations.
- + * @return An instance of `TFLSegmentationResult` initialized with an array of
- + * segmentations.
- */
- -- (instancetype)initWithSegmentations:(NSArray<TFLSegmentation *> *)segmentations;
- +- (instancetype)initWithSegmentations:(NSArray<TFLSegmentation*>*)segmentations;
-
- - (instancetype)init NS_UNAVAILABLE;
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.m
- index 33defd1139509..45b5510525fdc 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.m
- @@ -17,13 +17,16 @@
-
- @implementation TFLCategoryMask
-
- -- (instancetype)initWithWidth:(NSInteger)width height:(NSInteger)height mask:(UInt8 *)mask {
- +- (instancetype)initWithWidth:(NSInteger)width
- + height:(NSInteger)height
- + mask:(UInt8*)mask {
- self = [super init];
- if (self) {
- _width = width;
- _height = height;
- if (mask != NULL) {
- - _mask = [TFLCommonUtils mallocWithSize:width * height * sizeof(UInt8) error:nil];
- + _mask = [TFLCommonUtils mallocWithSize:width * height * sizeof(UInt8)
- + error:nil];
- if (_mask) {
- memcpy(_mask, mask, width * height * sizeof(UInt8));
- }
- @@ -32,7 +35,7 @@
- return self;
- }
-
- -- (id)copyWithZone:(NSZone *)zone {
- +- (id)copyWithZone:(NSZone*)zone {
- return [[TFLCategoryMask alloc] initWithWidth:self.width
- height:self.height
- mask:self.mask];
- @@ -46,13 +49,16 @@
-
- @implementation TFLConfidenceMask
-
- -- (instancetype)initWithWidth:(NSInteger)width height:(NSInteger)height mask:(float *)mask {
- +- (instancetype)initWithWidth:(NSInteger)width
- + height:(NSInteger)height
- + mask:(float*)mask {
- self = [super init];
- if (self) {
- _width = width;
- _height = height;
- if (mask != NULL) {
- - _mask = [TFLCommonUtils mallocWithSize:width * height * sizeof(float) error:nil];
- + _mask = [TFLCommonUtils mallocWithSize:width * height * sizeof(float)
- + error:nil];
- if (_mask) {
- memcpy(_mask, mask, width * height * sizeof(float));
- }
- @@ -61,7 +67,7 @@
- return self;
- }
-
- -- (id)copyWithZone:(NSZone *)zone {
- +- (id)copyWithZone:(NSZone*)zone {
- return [[TFLConfidenceMask alloc] initWithWidth:self.width
- height:self.height
- mask:self.mask];
- @@ -78,8 +84,8 @@
- - (instancetype)initWithRed:(NSUInteger)r
- green:(NSUInteger)g
- blue:(NSUInteger)b
- - label:(NSString *)label
- - displayName:(NSString *)displayName {
- + label:(NSString*)label
- + displayName:(NSString*)displayName {
- self = [super init];
- if (self) {
- _r = r;
- @@ -95,21 +101,25 @@
-
- @implementation TFLSegmentation
-
- -- (instancetype)initWithConfidenceMasks:(NSArray<TFLConfidenceMask *> *)confidenceMasks
- - coloredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels {
- +- (instancetype)
- + initWithConfidenceMasks:(NSArray<TFLConfidenceMask*>*)confidenceMasks
- + coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels {
- return [self initWithConfidenceMasks:confidenceMasks
- categoryMask:nil
- coloredLabels:coloredLabels];
- }
-
- -- (instancetype)initWithCategoryMask:(TFLCategoryMask *)categoryMask
- - coloredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels {
- - return [self initWithConfidenceMasks:nil categoryMask:categoryMask coloredLabels:coloredLabels];
- +- (instancetype)initWithCategoryMask:(TFLCategoryMask*)categoryMask
- + coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels {
- + return [self initWithConfidenceMasks:nil
- + categoryMask:categoryMask
- + coloredLabels:coloredLabels];
- }
-
- -- (instancetype)initWithConfidenceMasks:(NSArray<TFLConfidenceMask *> *)confidenceMasks
- - categoryMask:(TFLCategoryMask *)categoryMask
- - coloredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels {
- +- (instancetype)
- + initWithConfidenceMasks:(NSArray<TFLConfidenceMask*>*)confidenceMasks
- + categoryMask:(TFLCategoryMask*)categoryMask
- + coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels {
- self = [super init];
- if (self) {
- _confidenceMasks = confidenceMasks;
- @@ -123,7 +133,8 @@
-
- @implementation TFLSegmentationResult
-
- -- (instancetype)initWithSegmentations:(NSArray<TFLSegmentation *> *)segmentations {
- +- (instancetype)initWithSegmentations:
- + (NSArray<TFLSegmentation*>*)segmentations {
- self = [super init];
- if (self) {
- _segmentations = segmentations;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h
- index 99de5ad04febf..ac81a15ac11c6 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h
- @@ -27,15 +27,17 @@ NS_ASSUME_NONNULL_BEGIN
- @end
-
- /**
- - * Classifier API for NLClassification tasks with Bert models, categorizes string into different
- - * classes. The API expects a Bert based TFLite model with metadata populated.
- + * Classifier API for NLClassification tasks with Bert models, categorizes
- + * string into different classes. The API expects a Bert based TFLite model with
- + * metadata populated.
- *
- * The metadata should contain the following information:
- * 1 input_process_unit for Wordpiece/Sentencepiece Tokenizer.
- * 3 input tensors with names "ids", "mask" and "segment_ids".
- - * 1 output tensor of type float32[1, 2], with a optionally attached label file. If a label
- - * file is attached, the file should be a plain text file with one label per line, the number
- - * of labels should match the number of categories the model outputs.
- + * 1 output tensor of type float32[1, 2], with a optionally attached label
- + * file. If a label file is attached, the file should be a plain text file with
- + * one label per line, the number of labels should match the number of
- + * categories the model outputs.
- */
- @interface TFLBertNLClassifier : NSObject
-
- @@ -45,7 +47,7 @@ NS_ASSUME_NONNULL_BEGIN
- * @param modelPath Path to the classification model.
- * @return A TFLBertNLClassifier instance.
- */
- -+ (instancetype)bertNLClassifierWithModelPath:(NSString *)modelPath
- ++ (instancetype)bertNLClassifierWithModelPath:(NSString*)modelPath
- NS_SWIFT_NAME(bertNLClassifier(modelPath:));
-
- /**
- @@ -54,8 +56,9 @@ NS_ASSUME_NONNULL_BEGIN
- * @param modelPath Path to the classification model.
- * @return A TFLBertNLClassifier instance.
- */
- -+ (instancetype)bertNLClassifierWithModelPath:(NSString *)modelPath
- - options:(TFLBertNLClassifierOptions *)options
- ++ (instancetype)bertNLClassifierWithModelPath:(NSString*)modelPath
- + options:
- + (TFLBertNLClassifierOptions*)options
- NS_SWIFT_NAME(bertNLClassifier(modelPath:options:));
-
- /**
- @@ -65,7 +68,7 @@ NS_ASSUME_NONNULL_BEGIN
- * @param text input text to the model.
- * @return A NSDictionary of categorization results.
- */
- -- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text
- +- (NSDictionary<NSString*, NSNumber*>*)classifyWithText:(NSString*)text
- NS_SWIFT_NAME(classify(text:));
- @end
- NS_ASSUME_NONNULL_END
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h
- index ceb8d2ef9a307..41eb0fb76c9ea 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h
- @@ -23,14 +23,14 @@ NS_ASSUME_NONNULL_BEGIN
- @property(nonatomic) int inputTensorIndex;
- @property(nonatomic) int outputScoreTensorIndex;
- @property(nonatomic) int outputLabelTensorIndex;
- -@property(nonatomic) NSString *inputTensorName;
- -@property(nonatomic) NSString *outputScoreTensorName;
- -@property(nonatomic) NSString *outputLabelTensorName;
- +@property(nonatomic) NSString* inputTensorName;
- +@property(nonatomic) NSString* outputScoreTensorName;
- +@property(nonatomic) NSString* outputLabelTensorName;
- @end
-
- /**
- - * Classifier API for natural language classification tasks, categorizes string into different
- - * classes.
- + * Classifier API for natural language classification tasks, categorizes string
- + * into different classes.
- *
- * The API expects a TFLite model with the following input/output tensor:
- *
- @@ -39,25 +39,28 @@ NS_ASSUME_NONNULL_BEGIN
- *
- * Output score tensor
- * (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64/kTfLiteBool)
- - * output scores for each class, if type is one of the Int types, dequantize it, if it
- - * is Bool type, convert the values to 0.0 and 1.0 respectively.
- + * output scores for each class, if type is one of the Int types, dequantize
- + * it, if it is Bool type, convert the values to 0.0 and 1.0 respectively.
- *
- - * can have an optional associated file in metadata for labels, the file should be a
- - * plain text file with one label per line, the number of labels should match the number
- - * of categories the model outputs. Output label tensor: optional (kTfLiteString) -
- - * output classname for each class, should be of the same length with scores. If this
- - * tensor is not present, the API uses score indices as classnames. - will be ignored if
- - * output score tensor already has an associated label file.
- + * can have an optional associated file in metadata for labels, the file
- + * should be a plain text file with one label per line, the number of labels
- + * should match the number of categories the model outputs. Output label tensor:
- + * optional (kTfLiteString) - output classname for each class, should be of the
- + * same length with scores. If this tensor is not present, the API uses score
- + * indices as classnames. - will be ignored if output score tensor already has
- + * an associated label file.
- *
- * Optional Output label tensor (kTfLiteString/kTfLiteInt32)
- - * output classname for each class, should be of the same length with scores. If this
- - * tensor is not present, the API uses score indices as classnames.
- + * output classname for each class, should be of the same length with
- + * scores. If this tensor is not present, the API uses score indices as
- + * classnames.
- *
- - * will be ignored if output score tensor already has an associated labe file.
- + * will be ignored if output score tensor already has an associated labe
- + * file.
- *
- - * By default the API tries to find the input/output tensors with default configurations in
- - * TFLNLClassifierOptions, with tensor name prioritized over tensor index. The option is
- - * configurable for different TFLite models.
- + * By default the API tries to find the input/output tensors with default
- + * configurations in TFLNLClassifierOptions, with tensor name prioritized over
- + * tensor index. The option is configurable for different TFLite models.
- */
- @interface TFLNLClassifier : NSObject
-
- @@ -69,8 +72,8 @@ NS_ASSUME_NONNULL_BEGIN
- *
- * @return A TFLNLClassifier instance.
- */
- -+ (instancetype)nlClassifierWithModelPath:(NSString *)modelPath
- - options:(TFLNLClassifierOptions *)options
- ++ (instancetype)nlClassifierWithModelPath:(NSString*)modelPath
- + options:(TFLNLClassifierOptions*)options
- NS_SWIFT_NAME(nlClassifier(modelPath:options:));
-
- /**
- @@ -80,7 +83,7 @@ NS_ASSUME_NONNULL_BEGIN
- * @param text input text to the model.
- * @return A NSDictionary of categorization results.
- */
- -- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text
- +- (NSDictionary<NSString*, NSNumber*>*)classifyWithText:(NSString*)text
- NS_SWIFT_NAME(classify(text:));
- @end
- NS_ASSUME_NONNULL_END
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h
- index 57b7c69c70f62..446e2cb137dd9 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h
- @@ -54,13 +54,13 @@ struct TFLPos {
- * @param modelPath The file path to the tflite model.
- * @return A BertQuestionAnswerer instance.
- */
- -+ (instancetype)questionAnswererWithModelPath:(NSString *)modelPath
- ++ (instancetype)questionAnswererWithModelPath:(NSString*)modelPath
- NS_SWIFT_NAME(questionAnswerer(modelPath:));
-
- /**
- * Answers question based on the context. Could be empty if no answer was found
- * from the given context.
- - *
- + *
- * @param context Context the question bases on.
- * @param question Question to ask.
- *
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h
- index f228034147c40..7e38abe002623 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h
- @@ -31,29 +31,32 @@ NS_SWIFT_NAME(ImageClassifierOptions)
- * Base options that are used for creation of any type of task.
- * @discussion Please see `TFLBaseOptions` for more details.
- */
- -@property(nonatomic, copy) TFLBaseOptions *baseOptions;
- +@property(nonatomic, copy) TFLBaseOptions* baseOptions;
-
- /**
- * Options that configure the display and filtering of results.
- * @discussion Please see `TFLClassificationOptions` for more details.
- */
- -@property(nonatomic, copy) TFLClassificationOptions *classificationOptions;
- +@property(nonatomic, copy) TFLClassificationOptions* classificationOptions;
-
- /**
- - * Initializes a new `TFLImageClassifierOptions` with the absolute path to the model file
- - * stored locally on the device, set to the given the model path.
- + * Initializes a new `TFLImageClassifierOptions` with the absolute path to the
- + * model file stored locally on the device, set to the given the model path.
- *
- - * @discussion The external model file, must be a single standalone TFLite file. It could be packed
- - * with TFLite Model Metadata[1] and associated files if exist. Fail to provide the necessary
- - * metadata and associated files might result in errors. Check the [documentation]
- - * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement.
- + * @discussion The external model file, must be a single standalone TFLite file.
- + * It could be packed with TFLite Model Metadata[1] and associated files if
- + * exist. Fail to provide the necessary metadata and associated files might
- + * result in errors. Check the [documentation]
- + * (https://www.tensorflow.org/lite/convert/metadata) for each task about the
- + * specific requirement.
- *
- - * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
- + * @param modelPath An absolute path to a TensorFlow Lite model file stored
- + * locally on the device.
- *
- * @return An instance of `TFLImageClassifierOptions` initialized to the given
- * model path.
- */
- -- (instancetype)initWithModelPath:(NSString *)modelPath;
- +- (instancetype)initWithModelPath:(NSString*)modelPath;
-
- @end
-
- @@ -64,17 +67,19 @@ NS_SWIFT_NAME(ImageClassifier)
- @interface TFLImageClassifier : NSObject
-
- /**
- - * Creates a new instance of `TFLImageClassifier` from the given `TFLImageClassifierOptions`.
- + * Creates a new instance of `TFLImageClassifier` from the given
- + * `TFLImageClassifierOptions`.
- *
- * @param options The options to use for configuring the `TFLImageClassifier`.
- - * @param error An optional error parameter populated when there is an error in initializing
- - * the image classifier.
- + * @param error An optional error parameter populated when there is an error in
- + * initializing the image classifier.
- *
- - * @return A new instance of `TFLImageClassifier` with the given options. `nil` if there is an error
- - * in initializing the image classifier.
- + * @return A new instance of `TFLImageClassifier` with the given options. `nil`
- + * if there is an error in initializing the image classifier.
- */
- -+ (nullable instancetype)imageClassifierWithOptions:(TFLImageClassifierOptions *)options
- - error:(NSError **)error
- ++ (nullable instancetype)imageClassifierWithOptions:
- + (TFLImageClassifierOptions*)options
- + error:(NSError**)error
- NS_SWIFT_NAME(classifier(options:));
-
- + (instancetype)new NS_UNAVAILABLE;
- @@ -82,46 +87,49 @@ NS_SWIFT_NAME(ImageClassifier)
- /**
- * Performs classification on the given GMLImage.
- *
- - * @discussion This method currently supports classification of only the following types of images:
- + * @discussion This method currently supports classification of only the
- + * following types of images:
- * 1. RGB and RGBA images for `GMLImageSourceTypeImage`.
- * 2. kCVPixelFormatType_32BGRA for `GMLImageSourceTypePixelBuffer` and
- - * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to setup
- - * camera and get the frames for inference, you must request for this format
- - * from AVCaptureVideoDataOutput. Otherwise your classification
- - * results will be wrong.
- + * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to
- + * setup camera and get the frames for inference, you must request for this
- + * format from AVCaptureVideoDataOutput. Otherwise your classification results
- + * will be wrong.
- *
- * @param image An image to be classified, represented as a `GMLImage`.
- *
- - * @return A TFLClassificationResult with one set of results per image classifier head. `nil` if
- - * there is an error encountered during classification. Please see `TFLClassificationResult` for
- - * more details.
- + * @return A TFLClassificationResult with one set of results per image
- + * classifier head. `nil` if there is an error encountered during
- + * classification. Please see `TFLClassificationResult` for more details.
- */
- -- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image
- - error:(NSError **)error
- +- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image
- + error:(NSError**)error
- NS_SWIFT_NAME(classify(mlImage:));
-
- /**
- - * Performs classification on the pixels within the specified region of interest of the given
- - * `GMLImage`.
- + * Performs classification on the pixels within the specified region of interest
- + * of the given `GMLImage`.
- *
- - * @discussion This method currently supports inference on only following type of images:
- + * @discussion This method currently supports inference on only following type
- + * of images:
- * 1. RGB and RGBA images for `GMLImageSourceTypeImage`.
- * 2. kCVPixelFormatType_32BGRA for `GMLImageSourceTypePixelBuffer` and
- - * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to setup
- - * camera and get the frames for inference, you must request for this format
- - * from AVCaptureVideoDataOutput. Otherwise your classification
- - * results will be wrong.
- + * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to
- + * setup camera and get the frames for inference, you must request for this
- + * format from AVCaptureVideoDataOutput. Otherwise your classification results
- + * will be wrong.
- *
- * @param image An image to be classified, represented as a `GMLImage`.
- - * @param roi A CGRect specifying the region of interest within the given `GMLImage`, on which
- - * classification should be performed.
- + * @param roi A CGRect specifying the region of interest within the given
- + * `GMLImage`, on which classification should be performed.
- *
- - * @return A TFLClassificationResult with one set of results per image classifier head. `nil` if
- - * there is an error encountered during classification.
- + * @return A TFLClassificationResult with one set of results per image
- + * classifier head. `nil` if there is an error encountered during
- + * classification.
- */
- -- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image
- - regionOfInterest:(CGRect)roi
- - error:(NSError **)error
- +- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image
- + regionOfInterest:(CGRect)roi
- + error:(NSError**)error
- NS_SWIFT_NAME(classify(mlImage:regionOfInterest:));
-
- - (instancetype)init NS_UNAVAILABLE;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m
- index f8c09527bd902..79ad474054525 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m
- @@ -40,7 +40,7 @@
- return self;
- }
-
- -- (instancetype)initWithModelPath:(NSString *)modelPath {
- +- (instancetype)initWithModelPath:(NSString*)modelPath {
- self = [self init];
- if (self) {
- self.baseOptions.modelFile.filePath = modelPath;
- @@ -63,40 +63,45 @@
- return self;
- }
-
- -+ (nullable instancetype)imageClassifierWithOptions:(TFLImageClassifierOptions *)options
- - error:(NSError **)error {
- ++ (nullable instancetype)imageClassifierWithOptions:
- + (TFLImageClassifierOptions*)options
- + error:(NSError**)error {
- if (!options) {
- - [TFLCommonUtils createCustomError:error
- - withCode:TFLSupportErrorCodeInvalidArgumentError
- - description:@"TFLImageClassifierOptions argument cannot be nil."];
- + [TFLCommonUtils
- + createCustomError:error
- + withCode:TFLSupportErrorCodeInvalidArgumentError
- + description:@"TFLImageClassifierOptions argument cannot be nil."];
- return nil;
- }
-
- TfLiteImageClassifierOptions cOptions = TfLiteImageClassifierOptionsCreate();
-
- - if (![options.classificationOptions copyToCOptions:&(cOptions.classification_options)
- - error:error]) {
- - [options.classificationOptions
- - deleteAllocatedMemoryOfClassificationOptions:&(cOptions.classification_options)];
- + if (![options.classificationOptions
- + copyToCOptions:&(cOptions.classification_options)
- + error:error]) {
- + [options.classificationOptions deleteAllocatedMemoryOfClassificationOptions:
- + &(cOptions.classification_options)];
- return nil;
- }
-
- [options.baseOptions copyToCOptions:&(cOptions.base_options)];
-
- - TfLiteSupportError *cCreateClassifierError = NULL;
- - TfLiteImageClassifier *cImageClassifier =
- + TfLiteSupportError* cCreateClassifierError = NULL;
- + TfLiteImageClassifier* cImageClassifier =
- TfLiteImageClassifierFromOptions(&cOptions, &cCreateClassifierError);
-
- - [options.classificationOptions
- - deleteAllocatedMemoryOfClassificationOptions:&(cOptions.classification_options)];
- + [options.classificationOptions deleteAllocatedMemoryOfClassificationOptions:
- + &(cOptions.classification_options)];
-
- - // Populate iOS error if TfliteSupportError is not null and afterwards delete it.
- + // Populate iOS error if TfliteSupportError is not null and afterwards delete
- + // it.
- if (![TFLCommonUtils checkCError:cCreateClassifierError toError:error]) {
- TfLiteSupportErrorDelete(cCreateClassifierError);
- }
-
- - // Return nil if classifier evaluates to nil. If an error was generted by the C layer, it has
- - // already been populated to an NSError and deleted before returning from the method.
- + // Return nil if classifier evaluates to nil. If an error was generted by the
- + // C layer, it has already been populated to an NSError and deleted before
- + // returning from the method.
- if (!cImageClassifier) {
- return nil;
- }
- @@ -104,16 +109,16 @@
- return [[TFLImageClassifier alloc] initWithImageClassifier:cImageClassifier];
- }
-
- -- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image
- - error:(NSError **)error {
- +- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image
- + error:(NSError**)error {
- return [self classifyWithGMLImage:image
- regionOfInterest:CGRectMake(0, 0, image.width, image.height)
- error:error];
- }
-
- -- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image
- - regionOfInterest:(CGRect)roi
- - error:(NSError **)error {
- +- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image
- + regionOfInterest:(CGRect)roi
- + error:(NSError**)error {
- if (!image) {
- [TFLCommonUtils createCustomError:error
- withCode:TFLSupportErrorCodeInvalidArgumentError
- @@ -121,7 +126,7 @@
- return nil;
- }
-
- - TfLiteFrameBuffer *cFrameBuffer = [image cFrameBufferWithError:error];
- + TfLiteFrameBuffer* cFrameBuffer = [image cFrameBufferWithError:error];
-
- if (!cFrameBuffer) {
- return nil;
- @@ -132,7 +137,7 @@
- .width = roi.size.width,
- .height = roi.size.height};
-
- - TfLiteSupportError *classifyError = NULL;
- + TfLiteSupportError* classifyError = NULL;
- TfLiteClassificationResult *cClassificationResult = TfLiteImageClassifierClassifyWithRoi(
- _imageClassifier, cFrameBuffer, &boundingBox, &classifyError);
-
- @@ -147,8 +152,9 @@
- TfLiteSupportErrorDelete(classifyError);
- }
-
- - // Return nil if C result evaluates to nil. If an error was generted by the C layer, it has
- - // already been populated to an NSError and deleted before returning from the method.
- + // Return nil if C result evaluates to nil. If an error was generted by the C
- + // layer, it has already been populated to an NSError and deleted before
- + // returning from the method.
- if (!cClassificationResult) {
- return nil;
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h
- index 7b556dcd312e2..234e10d68b319 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h
- @@ -20,9 +20,10 @@
- NS_ASSUME_NONNULL_BEGIN
-
- /**
- - * Specifies the type of the output segmentation mask to be returned as the result
- - * of the image segmentation operation. This directs the `TFLImageSegmenter` to
- - * choose the type of post-processing to be performed on the raw model results.
- + * Specifies the type of the output segmentation mask to be returned as the
- + * result of the image segmentation operation. This directs the
- + * `TFLImageSegmenter` to choose the type of post-processing to be performed on
- + * the raw model results.
- */
- typedef NS_ENUM(NSUInteger, TFLOutputType) {
- /** Unspecified output type. */
- @@ -52,7 +53,7 @@ NS_SWIFT_NAME(ImageSegmenterOptions)
- * Base options that is used for creation of any type of task.
- * @discussion Please see `TFLBaseOptions` for more details.
- */
- -@property(nonatomic, copy) TFLBaseOptions *baseOptions;
- +@property(nonatomic, copy) TFLBaseOptions* baseOptions;
-
- /**
- * Specifies the type of output segmentation mask to be returned as a result
- @@ -63,24 +64,26 @@ NS_SWIFT_NAME(ImageSegmenterOptions)
- /**
- * Display names local for display names
- */
- -@property(nonatomic, copy) NSString *displayNamesLocale;
- +@property(nonatomic, copy) NSString* displayNamesLocale;
-
- /**
- - * Initializes a new `TFLImageSegmenterOptions` with the absolute path to the model file
- - * stored locally on the device, set to the given the model path.
- + * Initializes a new `TFLImageSegmenterOptions` with the absolute path to the
- + * model file stored locally on the device, set to the given the model path.
- * .
- * @discussion The external model file, must be a single standalone TFLite
- * file. It could be packed with TFLite Model Metadata[1] and associated files
- * if exist. Fail to provide the necessary metadata and associated files might
- - * result in errors. Check the [documentation](https://www.tensorflow.org/lite/convert/metadata)
- - * for each task about the specific requirement.
- + * result in errors. Check the
- + * [documentation](https://www.tensorflow.org/lite/convert/metadata) for each
- + * task about the specific requirement.
- *
- - * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
- + * @param modelPath An absolute path to a TensorFlow Lite model file stored
- + * locally on the device.
- *
- * @return An instance of `TFLImageSegmenterOptions` initialized to the given
- * model path.
- */
- -- (instancetype)initWithModelPath:(NSString *)modelPath;
- +- (instancetype)initWithModelPath:(NSString*)modelPath;
-
- @end
-
- @@ -88,17 +91,19 @@ NS_SWIFT_NAME(ImageSegmenter)
- @interface TFLImageSegmenter : NSObject
-
- /**
- - * Creates a new instance of `TFLImageSegmenter` from the given `TFLImageSegmenterOptions`.
- + * Creates a new instance of `TFLImageSegmenter` from the given
- + * `TFLImageSegmenterOptions`.
- *
- * @param options The options to use for configuring the `TFLImageSegmenter`.
- - * @param error An optional error parameter populated when there is an error in initializing
- - * the image segmenter.
- + * @param error An optional error parameter populated when there is an error in
- + * initializing the image segmenter.
- *
- - * @return A new instance of `TFLImageSegmenter` with the given options. `nil` if there is an error
- - * in initializing the image segmenter.
- + * @return A new instance of `TFLImageSegmenter` with the given options. `nil`
- + * if there is an error in initializing the image segmenter.
- */
- -+ (nullable instancetype)imageSegmenterWithOptions:(nonnull TFLImageSegmenterOptions *)options
- - error:(NSError **)error
- ++ (nullable instancetype)imageSegmenterWithOptions:
- + (nonnull TFLImageSegmenterOptions*)options
- + error:(NSError**)error
- NS_SWIFT_NAME(segmenter(options:));
-
- + (instancetype)new NS_UNAVAILABLE;
- @@ -106,22 +111,23 @@ NS_SWIFT_NAME(ImageSegmenter)
- /**
- * Performs segmentation on the given GMLImage.
- *
- - * @discussion This method currently supports segmentation of only the following types of images:
- + * @discussion This method currently supports segmentation of only the following
- + * types of images:
- * 1. RGB and RGBA images for `GMLImageSourceTypeImage`.
- * 2. kCVPixelFormatType_32BGRA for `GMLImageSourceTypePixelBuffer` and
- - * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to setup
- - * camera and get the frames for inference, you must request for this format
- - * from AVCaptureVideoDataOutput. Otherwise your segmentation
- - * results will be wrong.
- + * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to
- + * setup camera and get the frames for inference, you must request for this
- + * format from AVCaptureVideoDataOutput. Otherwise your segmentation results
- + * will be wrong.
- *
- * @param image An image to be segmented, represented as a `GMLImage`.
- *
- - * @return A TFLSegmentationResult that holds the segmentation masks returned by the image
- - * segmentation task. `nil` if there is an error encountered during segmentation. Please see
- - * `TFLSegmentationResult` for more details.
- + * @return A TFLSegmentationResult that holds the segmentation masks returned by
- + * the image segmentation task. `nil` if there is an error encountered during
- + * segmentation. Please see `TFLSegmentationResult` for more details.
- */
- -- (nullable TFLSegmentationResult *)segmentWithGMLImage:(GMLImage *)image
- - error:(NSError **)error
- +- (nullable TFLSegmentationResult*)segmentWithGMLImage:(GMLImage*)image
- + error:(NSError**)error
- NS_SWIFT_NAME(segment(mlImage:));
-
- - (instancetype)init NS_UNAVAILABLE;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.m
- index 70068bfdd645a..7b7f3211df952 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.m
- @@ -35,7 +35,7 @@
- return self;
- }
-
- -- (instancetype)initWithModelPath:(NSString *)modelPath {
- +- (instancetype)initWithModelPath:(NSString*)modelPath {
- self = [self init];
- if (self) {
- self.baseOptions.modelFile.filePath = modelPath;
- @@ -47,14 +47,14 @@
-
- @implementation TFLImageSegmenter {
- /** ImageSegmenter backed by C API */
- - TfLiteImageSegmenter *_imageSegmenter;
- + TfLiteImageSegmenter* _imageSegmenter;
- }
-
- - (void)dealloc {
- TfLiteImageSegmenterDelete(_imageSegmenter);
- }
-
- -- (instancetype)initWithImageSegmenter:(TfLiteImageSegmenter *)imageSegmenter {
- +- (instancetype)initWithImageSegmenter:(TfLiteImageSegmenter*)imageSegmenter {
- self = [super init];
- if (self) {
- _imageSegmenter = imageSegmenter;
- @@ -62,8 +62,9 @@
- return self;
- }
-
- -+ (nullable instancetype)imageSegmenterWithOptions:(nonnull TFLImageSegmenterOptions *)options
- - error:(NSError **)error {
- ++ (nullable instancetype)imageSegmenterWithOptions:
- + (nonnull TFLImageSegmenterOptions*)options
- + error:(NSError**)error {
- TfLiteImageSegmenterOptions cOptions = TfLiteImageSegmenterOptionsCreate();
-
- [options.baseOptions copyToCOptions:&(cOptions.base_options)];
- @@ -71,20 +72,22 @@
-
- if (options.displayNamesLocale) {
- if (options.displayNamesLocale.UTF8String) {
- - cOptions.display_names_locale = strdup(options.displayNamesLocale.UTF8String);
- + cOptions.display_names_locale =
- + strdup(options.displayNamesLocale.UTF8String);
- if (!cOptions.display_names_locale) {
- exit(-1); // Memory Allocation Failed.
- }
- } else {
- - [TFLCommonUtils createCustomError:error
- - withCode:TFLSupportErrorCodeInvalidArgumentError
- - description:@"Could not convert (NSString *) to (char *)."];
- + [TFLCommonUtils
- + createCustomError:error
- + withCode:TFLSupportErrorCodeInvalidArgumentError
- + description:@"Could not convert (NSString *) to (char *)."];
- return nil;
- }
- }
-
- - TfLiteSupportError *cCreateImageSegmenterError = nil;
- - TfLiteImageSegmenter *cImageSegmenter =
- + TfLiteSupportError* cCreateImageSegmenterError = nil;
- + TfLiteImageSegmenter* cImageSegmenter =
- TfLiteImageSegmenterFromOptions(&cOptions, &cCreateImageSegmenterError);
-
- // Freeing memory of allocated string.
- @@ -94,16 +97,17 @@
- TfLiteSupportErrorDelete(cCreateImageSegmenterError);
- }
-
- - // Return nil if C object detector evaluates to nil. If an error was generted by the C layer, it
- - // has already been populated to an NSError and deleted before returning from the method.
- + // Return nil if C object detector evaluates to nil. If an error was generted
- + // by the C layer, it has already been populated to an NSError and deleted
- + // before returning from the method.
- if (!cImageSegmenter) {
- return nil;
- }
- return [[TFLImageSegmenter alloc] initWithImageSegmenter:cImageSegmenter];
- }
-
- -- (nullable TFLSegmentationResult *)segmentWithGMLImage:(GMLImage *)image
- - error:(NSError **)error {
- +- (nullable TFLSegmentationResult*)segmentWithGMLImage:(GMLImage*)image
- + error:(NSError**)error {
- if (!image) {
- [TFLCommonUtils createCustomError:error
- withCode:TFLSupportErrorCodeInvalidArgumentError
- @@ -111,15 +115,15 @@
- return nil;
- }
-
- - TfLiteFrameBuffer *cFrameBuffer = [image cFrameBufferWithError:error];
- + TfLiteFrameBuffer* cFrameBuffer = [image cFrameBufferWithError:error];
-
- if (!cFrameBuffer) {
- return nil;
- }
-
- - TfLiteSupportError *cSegmentError = nil;
- - TfLiteSegmentationResult *cSegmentationResult =
- - TfLiteImageSegmenterSegment(_imageSegmenter, cFrameBuffer, &cSegmentError);
- + TfLiteSupportError* cSegmentError = nil;
- + TfLiteSegmentationResult* cSegmentationResult = TfLiteImageSegmenterSegment(
- + _imageSegmenter, cFrameBuffer, &cSegmentError);
-
- free(cFrameBuffer->buffer);
- cFrameBuffer->buffer = nil;
- @@ -132,13 +136,14 @@
- TfLiteSupportErrorDelete(cSegmentError);
- }
-
- - // Return nil if C result evaluates to nil. If an error was generted by the C layer, it has
- - // already been populated to an NSError and deleted before returning from the method.
- + // Return nil if C result evaluates to nil. If an error was generted by the C
- + // layer, it has already been populated to an NSError and deleted before
- + // returning from the method.
- if (!cSegmentationResult) {
- return nil;
- }
-
- - TFLSegmentationResult *segmentationResult =
- + TFLSegmentationResult* segmentationResult =
- [TFLSegmentationResult segmentationResultWithCResult:cSegmentationResult];
- TfLiteSegmentationResultDelete(cSegmentationResult);
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h
- index 5e3a0e7186cfe..db76c90cc6868 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h
- @@ -30,28 +30,31 @@ NS_SWIFT_NAME(ObjectDetectorOptions)
- * Base options that is used for creation of any type of task.
- * @discussion Please see `TFLBaseOptions` for more details.
- */
- -@property(nonatomic, copy) TFLBaseOptions *baseOptions;
- +@property(nonatomic, copy) TFLBaseOptions* baseOptions;
-
- /**
- * Options that configure the display and filtering of results.
- * @discussion Please see `TFLClassificationOptions` for more details.
- */
- -@property(nonatomic, copy) TFLClassificationOptions *classificationOptions;
- +@property(nonatomic, copy) TFLClassificationOptions* classificationOptions;
-
- /**
- - * Initializes a new `TFLObjectDetectorOptions` with the absolute path to the model file
- - * stored locally on the device, set to the given the model path.
- + * Initializes a new `TFLObjectDetectorOptions` with the absolute path to the
- + * model file stored locally on the device, set to the given the model path.
- *
- - * @discussion The external model file, must be a single standalone TFLite file. It could be packed
- - * with TFLite Model Metadata[1] and associated files if exist. Fail to provide the necessary
- - * metadata and associated files might result in errors. Check the [documentation]
- - * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement.
- + * @discussion The external model file, must be a single standalone TFLite file.
- + * It could be packed with TFLite Model Metadata[1] and associated files if
- + * exist. Fail to provide the necessary metadata and associated files might
- + * result in errors. Check the [documentation]
- + * (https://www.tensorflow.org/lite/convert/metadata) for each task about the
- + * specific requirement.
- *
- - * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
- + * @param modelPath An absolute path to a TensorFlow Lite model file stored
- + * locally on the device.
- * @return An instance of `TFLObjectDetectorOptions` initialized to the given
- * model path.
- */
- -- (instancetype)initWithModelPath:(NSString *)modelPath;
- +- (instancetype)initWithModelPath:(NSString*)modelPath;
-
- @end
-
- @@ -59,40 +62,43 @@ NS_SWIFT_NAME(ObjectDetector)
- @interface TFLObjectDetector : NSObject
-
- /**
- - * Creates a new instance of `TFLObjectDetector` from the given `TFLObjectDetectorOptions`.
- + * Creates a new instance of `TFLObjectDetector` from the given
- + * `TFLObjectDetectorOptions`.
- *
- * @param options The options to use for configuring the `TFLObjectDetector`.
- - * @param error An optional error parameter populated when there is an error in initializing
- - * the object detector.
- + * @param error An optional error parameter populated when there is an error in
- + * initializing the object detector.
- *
- - * @return A new instance of `TFLObjectDetector` with the given options. `nil` if there is an error
- - * in initializing the object detector.
- + * @return A new instance of `TFLObjectDetector` with the given options. `nil`
- + * if there is an error in initializing the object detector.
- */
- -+ (nullable instancetype)objectDetectorWithOptions:(TFLObjectDetectorOptions *)options
- - error:(NSError **)error
- ++ (nullable instancetype)objectDetectorWithOptions:
- + (TFLObjectDetectorOptions*)options
- + error:(NSError**)error
- NS_SWIFT_NAME(detector(options:));
-
- + (instancetype)new NS_UNAVAILABLE;
-
- /**
- * Performs object detection on the given GMLImage.
- - * @discussion This method currently supports object detection on only the following types of
- - * images:
- + * @discussion This method currently supports object detection on only the
- + * following types of images:
- * 1. RGB and RGBA images for `GMLImageSourceTypeImage`.
- * 2. `kCVPixelFormatType_32BGRA` for `GMLImageSourceTypePixelBuffer` and
- - * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to setup
- - * camera and get the frames for inference, you must request for this format
- - * from AVCaptureVideoDataOutput. Otherwise your object detection
- - * results will be wrong.
- + * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to
- + * setup camera and get the frames for inference, you must request for this
- + * format from AVCaptureVideoDataOutput. Otherwise your object detection results
- + * will be wrong.
- *
- - * @param image An image on which object detection is to be performed, represented as a `GMLImage`.
- + * @param image An image on which object detection is to be performed,
- + * represented as a `GMLImage`.
- *
- - * @return A `TFLDetectionResult` holding an array of TFLDetection objects, each having a bounding
- - * box specifying the region the were detected in and an array of predicted classes. Please see
- - * `TFLDetectionResult` for more details.
- + * @return A `TFLDetectionResult` holding an array of TFLDetection objects, each
- + * having a bounding box specifying the region the were detected in and an array
- + * of predicted classes. Please see `TFLDetectionResult` for more details.
- */
- -- (nullable TFLDetectionResult *)detectWithGMLImage:(GMLImage *)image
- - error:(NSError **)error
- +- (nullable TFLDetectionResult*)detectWithGMLImage:(GMLImage*)image
- + error:(NSError**)error
- NS_SWIFT_NAME(detect(mlImage:));
-
- - (instancetype)init NS_UNAVAILABLE;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.m
- index 31cb241a2a448..def2e5b0b4877 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.m
- @@ -40,7 +40,7 @@
- return self;
- }
-
- -- (instancetype)initWithModelPath:(NSString *)modelPath {
- +- (instancetype)initWithModelPath:(NSString*)modelPath {
- self = [self init];
- if (self) {
- self.baseOptions.modelFile.filePath = modelPath;
- @@ -63,40 +63,45 @@
- return self;
- }
-
- -+ (nullable instancetype)objectDetectorWithOptions:(TFLObjectDetectorOptions *)options
- - error:(NSError **)error {
- ++ (nullable instancetype)objectDetectorWithOptions:
- + (TFLObjectDetectorOptions*)options
- + error:(NSError**)error {
- if (!options) {
- - [TFLCommonUtils createCustomError:error
- - withCode:TFLSupportErrorCodeInvalidArgumentError
- - description:@"TFLObjectDetectorOptions argument cannot be nil."];
- + [TFLCommonUtils
- + createCustomError:error
- + withCode:TFLSupportErrorCodeInvalidArgumentError
- + description:@"TFLObjectDetectorOptions argument cannot be nil."];
- return nil;
- }
-
- TfLiteObjectDetectorOptions cOptions = TfLiteObjectDetectorOptionsCreate();
- - if (![options.classificationOptions copyToCOptions:&(cOptions.classification_options)
- - error:error]) {
- + if (![options.classificationOptions
- + copyToCOptions:&(cOptions.classification_options)
- + error:error]) {
- // Deallocating any allocated memory on failure.
- - [options.classificationOptions
- - deleteAllocatedMemoryOfClassificationOptions:&(cOptions.classification_options)];
- + [options.classificationOptions deleteAllocatedMemoryOfClassificationOptions:
- + &(cOptions.classification_options)];
- return nil;
- }
-
- [options.baseOptions copyToCOptions:&(cOptions.base_options)];
-
- - TfLiteSupportError *cCreateObjectDetectorError = nil;
- - TfLiteObjectDetector *cObjectDetector =
- + TfLiteSupportError* cCreateObjectDetectorError = nil;
- + TfLiteObjectDetector* cObjectDetector =
- TfLiteObjectDetectorFromOptions(&cOptions, &cCreateObjectDetectorError);
-
- - [options.classificationOptions
- - deleteAllocatedMemoryOfClassificationOptions:&(cOptions.classification_options)];
- + [options.classificationOptions deleteAllocatedMemoryOfClassificationOptions:
- + &(cOptions.classification_options)];
-
- - // Populate iOS error if TfliteSupportError is not null and afterwards delete it.
- + // Populate iOS error if TfliteSupportError is not null and afterwards delete
- + // it.
- if (![TFLCommonUtils checkCError:cCreateObjectDetectorError toError:error]) {
- TfLiteSupportErrorDelete(cCreateObjectDetectorError);
- }
-
- - // Return nil if C object detector evaluates to nil. If an error was generted by the C layer, it
- - // has already been populated to an NSError and deleted before returning from the method.
- + // Return nil if C object detector evaluates to nil. If an error was generted
- + // by the C layer, it has already been populated to an NSError and deleted
- + // before returning from the method.
- if (!cObjectDetector) {
- return nil;
- }
- @@ -104,8 +109,8 @@
- return [[TFLObjectDetector alloc] initWithObjectDetector:cObjectDetector];
- }
-
- -- (nullable TFLDetectionResult *)detectWithGMLImage:(GMLImage *)image
- - error:(NSError **)error {
- +- (nullable TFLDetectionResult*)detectWithGMLImage:(GMLImage*)image
- + error:(NSError**)error {
- if (!image) {
- [TFLCommonUtils createCustomError:error
- withCode:TFLSupportErrorCodeInvalidArgumentError
- @@ -113,14 +118,14 @@
- return nil;
- }
-
- - TfLiteFrameBuffer *cFrameBuffer = [image cFrameBufferWithError:error];
- + TfLiteFrameBuffer* cFrameBuffer = [image cFrameBufferWithError:error];
-
- if (!cFrameBuffer) {
- return nil;
- }
-
- - TfLiteSupportError *cDetectError = nil;
- - TfLiteDetectionResult *cDetectionResult =
- + TfLiteSupportError* cDetectError = nil;
- + TfLiteDetectionResult* cDetectionResult =
- TfLiteObjectDetectorDetect(_objectDetector, cFrameBuffer, &cDetectError);
-
- free(cFrameBuffer->buffer);
- @@ -134,8 +139,9 @@
- TfLiteSupportErrorDelete(cDetectError);
- }
-
- - // Return nil if C result evaluates to nil. If an error was generted by the C layer, it has
- - // already been populated to an NSError and deleted before returning from the method.
- + // Return nil if C result evaluates to nil. If an error was generted by the C
- + // layer, it has already been populated to an NSError and deleted before
- + // returning from the method.
- if (!cDetectionResult) {
- return nil;
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h
- index 77c3e33185b9f..8524903b36602 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h
- @@ -36,7 +36,7 @@ NS_ASSUME_NONNULL_BEGIN
- * @return The TfLiteFrameBuffer created from the gmlImage which can be used
- * with the TF Lite Task Vision C library.
- */
- -- (nullable TfLiteFrameBuffer *)cFrameBufferWithError:(NSError *_Nullable *)error;
- +- (nullable TfLiteFrameBuffer*)cFrameBufferWithError:(NSError* _Nullable*)error;
-
- /**
- * Gets grayscale pixel buffer from GMLImage if source type is
- @@ -61,9 +61,9 @@ NS_ASSUME_NONNULL_BEGIN
- * @return The GMLImage object contains the loaded image. This method returns
- * nil if it cannot load the image.
- */
- -+ (nullable GMLImage *)imageFromBundleWithClass:(Class)classObject
- - fileName:(NSString *)name
- - ofType:(NSString *)type
- ++ (nullable GMLImage*)imageFromBundleWithClass:(Class)classObject
- + fileName:(NSString*)name
- + ofType:(NSString*)type
- NS_SWIFT_NAME(imageFromBundle(class:filename:type:));
-
- @end
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.m
- index d1ab5105448fe..532f75ef25a6c 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.m
- @@ -25,35 +25,38 @@
-
- @interface TFLCVPixelBufferUtils : NSObject
-
- -+ (TfLiteFrameBuffer *)cFrameBufferWithWidth:(int)width
- - height:(int)height
- - frameBufferFormat:(enum TfLiteFrameBufferFormat)frameBufferFormat
- - buffer:(uint8_t *)buffer
- - error:(NSError **)error;
- ++ (TfLiteFrameBuffer*)cFrameBufferWithWidth:(int)width
- + height:(int)height
- + frameBufferFormat:
- + (enum TfLiteFrameBufferFormat)frameBufferFormat
- + buffer:(uint8_t*)buffer
- + error:(NSError**)error;
-
- -+ (TfLiteFrameBuffer *)cFramebufferFromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer
- - error:(NSError **)error;
- ++ (TfLiteFrameBuffer*)cFramebufferFromCVPixelBuffer:
- + (CVPixelBufferRef)pixelBuffer
- + error:(NSError**)error;
-
- @end
-
- @interface UIImage (RawPixelDataUtils)
- -- (TfLiteFrameBuffer *)frameBufferWithError:(NSError **)error;
- +- (TfLiteFrameBuffer*)frameBufferWithError:(NSError**)error;
- - (CVPixelBufferRef)grayScalePixelBuffer;
- @end
-
- @implementation TFLCVPixelBufferUtils
-
- -+ (TfLiteFrameBuffer *)cFrameBufferWithWidth:(int)width
- - height:(int)height
- - frameBufferFormat:(enum TfLiteFrameBufferFormat)frameBufferFormat
- - buffer:(uint8_t *)buffer
- - error:(NSError **)error {
- ++ (TfLiteFrameBuffer*)cFrameBufferWithWidth:(int)width
- + height:(int)height
- + frameBufferFormat:
- + (enum TfLiteFrameBufferFormat)frameBufferFormat
- + buffer:(uint8_t*)buffer
- + error:(NSError**)error {
- if (!buffer) {
- return NULL;
- }
-
- - TfLiteFrameBuffer *cFrameBuffer = [TFLCommonUtils mallocWithSize:sizeof(TfLiteFrameBuffer)
- - error:error];
- + TfLiteFrameBuffer* cFrameBuffer =
- + [TFLCommonUtils mallocWithSize:sizeof(TfLiteFrameBuffer) error:error];
-
- if (cFrameBuffer) {
- cFrameBuffer->dimension.width = width;
- @@ -65,17 +68,18 @@
- return cFrameBuffer;
- }
-
- -+ (uint8_t *)createRGBImageDatafromImageData:(uint8_t *)data
- - withWidth:(size_t)width
- - height:(size_t)height
- - stride:(size_t)stride
- - pixelBufferFormat:(OSType)pixelBufferFormatType
- - error:(NSError **)error {
- ++ (uint8_t*)createRGBImageDatafromImageData:(uint8_t*)data
- + withWidth:(size_t)width
- + height:(size_t)height
- + stride:(size_t)stride
- + pixelBufferFormat:(OSType)pixelBufferFormatType
- + error:(NSError**)error {
- NSInteger destinationChannelCount = 3;
- size_t destinationBytesPerRow = width * destinationChannelCount;
-
- - uint8_t *destPixelBufferAddress =
- - [TFLCommonUtils mallocWithSize:sizeof(uint8_t) * height * destinationBytesPerRow error:error];
- + uint8_t* destPixelBufferAddress = [TFLCommonUtils
- + mallocWithSize:sizeof(uint8_t) * height * destinationBytesPerRow
- + error:error];
-
- if (!destPixelBufferAddress) {
- return NULL;
- @@ -95,19 +99,23 @@
-
- switch (pixelBufferFormatType) {
- case kCVPixelFormatType_32RGBA: {
- - convertError = vImageConvert_RGBA8888toRGB888(&srcBuffer, &destBuffer, kvImageNoFlags);
- + convertError = vImageConvert_RGBA8888toRGB888(&srcBuffer, &destBuffer,
- + kvImageNoFlags);
- break;
- }
- case kCVPixelFormatType_32BGRA: {
- - convertError = vImageConvert_BGRA8888toRGB888(&srcBuffer, &destBuffer, kvImageNoFlags);
- + convertError = vImageConvert_BGRA8888toRGB888(&srcBuffer, &destBuffer,
- + kvImageNoFlags);
- break;
- }
- default: {
- - [TFLCommonUtils createCustomError:error
- - withCode:TFLSupportErrorCodeInvalidArgumentError
- - description:@"Invalid source pixel buffer format. Expecting one of "
- - @"kCVPixelFormatType_32RGBA, kCVPixelFormatType_32BGRA, "
- - @"kCVPixelFormatType_32ARGB"];
- + [TFLCommonUtils
- + createCustomError:error
- + withCode:TFLSupportErrorCodeInvalidArgumentError
- + description:
- + @"Invalid source pixel buffer format. Expecting one of "
- + @"kCVPixelFormatType_32RGBA, kCVPixelFormatType_32BGRA, "
- + @"kCVPixelFormatType_32ARGB"];
-
- free(destPixelBufferAddress);
- return NULL;
- @@ -126,16 +134,17 @@
- return destPixelBufferAddress;
- }
-
- -+ (uint8_t *)createRGBImageDatafromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer
- - error:(NSError **)error {
- ++ (uint8_t*)createRGBImageDatafromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer
- + error:(NSError**)error {
- CVPixelBufferLockBaseAddress(pixelBuffer, 0);
-
- - uint8_t *rgbData = [TFLCVPixelBufferUtils
- + uint8_t* rgbData = [TFLCVPixelBufferUtils
- createRGBImageDatafromImageData:CVPixelBufferGetBaseAddress(pixelBuffer)
- withWidth:CVPixelBufferGetWidth(pixelBuffer)
- height:CVPixelBufferGetHeight(pixelBuffer)
- stride:CVPixelBufferGetBytesPerRow(pixelBuffer)
- - pixelBufferFormat:CVPixelBufferGetPixelFormatType(pixelBuffer)
- + pixelBufferFormat:CVPixelBufferGetPixelFormatType(
- + pixelBuffer)
- error:error];
-
- CVPixelBufferUnlockBaseAddress(pixelBuffer, 0);
- @@ -143,9 +152,10 @@
- return rgbData;
- }
-
- -+ (TfLiteFrameBuffer *)cFramebufferFromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer
- - error:(NSError **)error {
- - uint8_t *buffer = NULL;
- ++ (TfLiteFrameBuffer*)cFramebufferFromCVPixelBuffer:
- + (CVPixelBufferRef)pixelBuffer
- + error:(NSError**)error {
- + uint8_t* buffer = NULL;
- enum TfLiteFrameBufferFormat cPixelFormat = kRGB;
-
- OSType pixelBufferFormat = CVPixelBufferGetPixelFormatType(pixelBuffer);
- @@ -154,14 +164,18 @@
- case kCVPixelFormatType_32BGRA: {
- cPixelFormat = kRGB;
-
- - buffer = [TFLCVPixelBufferUtils createRGBImageDatafromCVPixelBuffer:pixelBuffer error:error];
- + buffer =
- + [TFLCVPixelBufferUtils createRGBImageDatafromCVPixelBuffer:pixelBuffer
- + error:error];
- break;
- }
- default: {
- - [TFLCommonUtils createCustomError:error
- - withCode:TFLSupportErrorCodeInvalidArgumentError
- - description:@"Unsupported pixel format for CVPixelBuffer. Supported "
- - @"pixel format types are kCVPixelFormatType_32BGRA"];
- + [TFLCommonUtils
- + createCustomError:error
- + withCode:TFLSupportErrorCodeInvalidArgumentError
- + description:
- + @"Unsupported pixel format for CVPixelBuffer. Supported "
- + @"pixel format types are kCVPixelFormatType_32BGRA"];
- }
- }
-
- @@ -176,8 +190,8 @@
-
- @implementation UIImage (RawPixelDataUtils)
-
- -- (TfLiteFrameBuffer *)frameBufferWithError:(NSError **)error {
- - TfLiteFrameBuffer *frameBuffer = nil;
- +- (TfLiteFrameBuffer*)frameBufferWithError:(NSError**)error {
- + TfLiteFrameBuffer* frameBuffer = nil;
-
- if (self.CGImage) {
- frameBuffer = [self frameBufferFromCGImage:self.CGImage error:error];
- @@ -202,59 +216,65 @@
- }
-
- CGDataProviderRef imageDataProvider = CGImageGetDataProvider(cgImage);
- - CFMutableDataRef mutableDataRef =
- - CFDataCreateMutableCopy(kCFAllocatorDefault, 0, CGDataProviderCopyData(imageDataProvider));
- + CFMutableDataRef mutableDataRef = CFDataCreateMutableCopy(
- + kCFAllocatorDefault, 0, CGDataProviderCopyData(imageDataProvider));
-
- - UInt8 *pixelData = CFDataGetMutableBytePtr(mutableDataRef);
- + UInt8* pixelData = CFDataGetMutableBytePtr(mutableDataRef);
-
- - if (pixelData == nil) return nil;
- + if (pixelData == nil)
- + return nil;
-
- CVPixelBufferRef cvPixelBuffer = nil;
-
- - CVPixelBufferCreateWithBytes(kCFAllocatorDefault, CGImageGetWidth(cgImage),
- - CGImageGetHeight(cgImage), kCVPixelFormatType_OneComponent8,
- - pixelData, CGImageGetBytesPerRow(cgImage), nil, nil, options,
- - &cvPixelBuffer);
- + CVPixelBufferCreateWithBytes(
- + kCFAllocatorDefault, CGImageGetWidth(cgImage), CGImageGetHeight(cgImage),
- + kCVPixelFormatType_OneComponent8, pixelData,
- + CGImageGetBytesPerRow(cgImage), nil, nil, options, &cvPixelBuffer);
-
- return cvPixelBuffer;
- }
-
- -+ (UInt8 *_Nullable)pixelDataFromCGImage:(CGImageRef)cgImage error:(NSError **)error {
- ++ (UInt8* _Nullable)pixelDataFromCGImage:(CGImageRef)cgImage
- + error:(NSError**)error {
- size_t width = CGImageGetWidth(cgImage);
- size_t height = CGImageGetHeight(cgImage);
-
- NSInteger bitsPerComponent = 8;
- NSInteger channelCount = 4;
- - UInt8 *buffer_to_return = NULL;
- + UInt8* buffer_to_return = NULL;
-
- CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB();
- size_t bytesPerRow = channelCount * width;
-
- // iOS infers bytesPerRow if it is set to 0.
- - // See https://developer.apple.com/documentation/coregraphics/1455939-cgbitmapcontextcreate
- + // See
- + // https://developer.apple.com/documentation/coregraphics/1455939-cgbitmapcontextcreate
- // But for segmentation test image, this was not the case.
- // Hence setting it to the value of channelCount*width.
- // kCGImageAlphaNoneSkipLast specifies that Alpha will always be next to B.
- // kCGBitmapByteOrder32Big specifies that R will be stored before B.
- // In combination they signify a pixelFormat of kCVPixelFormatType32RGBA.
- - CGBitmapInfo bitMapinfoFor32RGBA = kCGImageAlphaNoneSkipLast | kCGBitmapByteOrder32Big;
- - CGContextRef context = CGBitmapContextCreate(nil, width, height, bitsPerComponent, bytesPerRow,
- - colorSpace, bitMapinfoFor32RGBA);
- + CGBitmapInfo bitMapinfoFor32RGBA =
- + kCGImageAlphaNoneSkipLast | kCGBitmapByteOrder32Big;
- + CGContextRef context =
- + CGBitmapContextCreate(nil, width, height, bitsPerComponent, bytesPerRow,
- + colorSpace, bitMapinfoFor32RGBA);
-
- if (context) {
- CGContextDrawImage(context, CGRectMake(0, 0, width, height), cgImage);
- - uint8_t *srcData = CGBitmapContextGetData(context);
- + uint8_t* srcData = CGBitmapContextGetData(context);
-
- if (srcData) {
- - // We have drawn the image as an RGBA image with 8 bitsPerComponent and hence can safely input
- - // a pixel format of type kCVPixelFormatType_32RGBA for conversion by vImage.
- - buffer_to_return =
- - [TFLCVPixelBufferUtils createRGBImageDatafromImageData:srcData
- - withWidth:width
- - height:height
- - stride:bytesPerRow
- - pixelBufferFormat:kCVPixelFormatType_32RGBA
- - error:error];
- + // We have drawn the image as an RGBA image with 8 bitsPerComponent and
- + // hence can safely input a pixel format of type kCVPixelFormatType_32RGBA
- + // for conversion by vImage.
- + buffer_to_return = [TFLCVPixelBufferUtils
- + createRGBImageDatafromImageData:srcData
- + withWidth:width
- + height:height
- + stride:bytesPerRow
- + pixelBufferFormat:kCVPixelFormatType_32RGBA
- + error:error];
- }
-
- CGContextRelease(context);
- @@ -265,18 +285,21 @@
- return buffer_to_return;
- }
-
- -- (TfLiteFrameBuffer *)frameBufferFromCGImage:(CGImageRef)cgImage error:(NSError **)error {
- - UInt8 *buffer = [UIImage pixelDataFromCGImage:cgImage error:error];
- +- (TfLiteFrameBuffer*)frameBufferFromCGImage:(CGImageRef)cgImage
- + error:(NSError**)error {
- + UInt8* buffer = [UIImage pixelDataFromCGImage:cgImage error:error];
-
- - return [TFLCVPixelBufferUtils cFrameBufferWithWidth:(int)CGImageGetWidth(cgImage)
- - height:(int)CGImageGetHeight(cgImage)
- - frameBufferFormat:kRGB
- - buffer:buffer
- - error:error];
- + return [TFLCVPixelBufferUtils
- + cFrameBufferWithWidth:(int)CGImageGetWidth(cgImage)
- + height:(int)CGImageGetHeight(cgImage)
- + frameBufferFormat:kRGB
- + buffer:buffer
- + error:error];
- }
-
- -- (TfLiteFrameBuffer *)frameBufferFromCIImage:(CIImage *)ciImage error:(NSError **)error {
- - uint8_t *buffer = NULL;
- +- (TfLiteFrameBuffer*)frameBufferFromCIImage:(CIImage*)ciImage
- + error:(NSError**)error {
- + uint8_t* buffer = NULL;
-
- int width = 0;
- int height = 0;
- @@ -285,17 +308,20 @@
- width = (int)CVPixelBufferGetWidth(ciImage.pixelBuffer);
- height = (int)CVPixelBufferGetHeight(ciImage.pixelBuffer);
-
- - buffer = [TFLCVPixelBufferUtils createRGBImageDatafromCVPixelBuffer:ciImage.pixelBuffer
- - error:error];
- + buffer = [TFLCVPixelBufferUtils
- + createRGBImageDatafromCVPixelBuffer:ciImage.pixelBuffer
- + error:error];
-
- } else if (ciImage.CGImage) {
- buffer = [UIImage pixelDataFromCGImage:ciImage.CGImage error:error];
- width = (int)CGImageGetWidth(ciImage.CGImage);
- height = (int)CGImageGetWidth(ciImage.CGImage);
- } else {
- - [TFLCommonUtils createCustomError:error
- - withCode:TFLSupportErrorCodeInvalidArgumentError
- - description:@"CIImage should have CGImage or CVPixelBuffer info."];
- + [TFLCommonUtils
- + createCustomError:error
- + withCode:TFLSupportErrorCodeInvalidArgumentError
- + description:
- + @"CIImage should have CGImage or CVPixelBuffer info."];
- }
-
- return [TFLCVPixelBufferUtils cFrameBufferWithWidth:width
- @@ -309,19 +335,23 @@
-
- @implementation GMLImage (Utils)
-
- -- (nullable TfLiteFrameBuffer *)cFrameBufferWithError:(NSError *_Nullable *)error {
- - TfLiteFrameBuffer *cFrameBuffer = NULL;
- +- (nullable TfLiteFrameBuffer*)cFrameBufferWithError:
- + (NSError* _Nullable*)error {
- + TfLiteFrameBuffer* cFrameBuffer = NULL;
-
- switch (self.imageSourceType) {
- case GMLImageSourceTypeSampleBuffer: {
- - CVPixelBufferRef sampleImagePixelBuffer = CMSampleBufferGetImageBuffer(self.sampleBuffer);
- - cFrameBuffer = [TFLCVPixelBufferUtils cFramebufferFromCVPixelBuffer:sampleImagePixelBuffer
- - error:error];
- + CVPixelBufferRef sampleImagePixelBuffer =
- + CMSampleBufferGetImageBuffer(self.sampleBuffer);
- + cFrameBuffer = [TFLCVPixelBufferUtils
- + cFramebufferFromCVPixelBuffer:sampleImagePixelBuffer
- + error:error];
- break;
- }
- case GMLImageSourceTypePixelBuffer: {
- - cFrameBuffer = [TFLCVPixelBufferUtils cFramebufferFromCVPixelBuffer:self.pixelBuffer
- - error:error];
- + cFrameBuffer =
- + [TFLCVPixelBufferUtils cFramebufferFromCVPixelBuffer:self.pixelBuffer
- + error:error];
- break;
- }
- case GMLImageSourceTypeImage: {
- @@ -352,14 +382,17 @@
- return nil;
- }
-
- -+ (GMLImage *)imageFromBundleWithClass:(Class)classObject
- - fileName:(NSString *)name
- - ofType:(NSString *)type {
- - NSString *imagePath = [[NSBundle bundleForClass:classObject] pathForResource:name ofType:type];
- - if (!imagePath) return nil;
- ++ (GMLImage*)imageFromBundleWithClass:(Class)classObject
- + fileName:(NSString*)name
- + ofType:(NSString*)type {
- + NSString* imagePath =
- + [[NSBundle bundleForClass:classObject] pathForResource:name ofType:type];
- + if (!imagePath)
- + return nil;
-
- - UIImage *image = [[UIImage alloc] initWithContentsOfFile:imagePath];
- - if (!image) return nil;
- + UIImage* image = [[UIImage alloc] initWithContentsOfFile:imagePath];
- + if (!image)
- + return nil;
-
- return [[GMLImage alloc] initWithImage:image];
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/TFLRingBufferTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/TFLRingBufferTests.m
- index 3e2df5d4bf023..cd389b9c0a9a8 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/TFLRingBufferTests.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/TFLRingBufferTests.m
- @@ -17,10 +17,12 @@
- #import "tensorflow_lite_support/ios/sources/TFLCommon.h"
- #import "tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h"
-
- -#define VerifyError(error, expectedErrorDomain, expectedErrorCode, expectedLocalizedDescription) \
- - XCTAssertEqual(error.domain, expectedErrorDomain); \
- - XCTAssertEqual(error.code, expectedErrorCode); \
- - XCTAssertEqualObjects(error.localizedDescription, expectedLocalizedDescription);
- +#define VerifyError(error, expectedErrorDomain, expectedErrorCode, \
- + expectedLocalizedDescription) \
- + XCTAssertEqual(error.domain, expectedErrorDomain); \
- + XCTAssertEqual(error.code, expectedErrorCode); \
- + XCTAssertEqualObjects(error.localizedDescription, \
- + expectedLocalizedDescription);
-
- NS_ASSUME_NONNULL_BEGIN
-
- @@ -33,15 +35,20 @@ NS_ASSUME_NONNULL_BEGIN
- NSInteger inDataLength = 5;
- float inData[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
-
- - TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inData[0]) size:inDataLength];
- + TFLFloatBuffer* inBuffer =
- + [[TFLFloatBuffer alloc] initWithData:&(inData[0]) size:inDataLength];
-
- NSInteger bufferSize = 5;
- - TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
- + TFLRingBuffer* ringBuffer =
- + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
-
- - XCTAssertTrue([ringBuffer loadBuffer:inBuffer offset:0 size:inDataLength error:nil]);
- + XCTAssertTrue([ringBuffer loadBuffer:inBuffer
- + offset:0
- + size:inDataLength
- + error:nil]);
- // State after load: [1.0, 2.0, 3.0, 4.0, 5.0]
-
- - TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer;
- + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer;
- XCTAssertNotNil(outBuffer);
- XCTAssertEqual(outBuffer.size, bufferSize);
-
- @@ -55,16 +62,21 @@ NS_ASSUME_NONNULL_BEGIN
- - (void)testLoadSucceedsWithPartialLengthBuffer {
- NSInteger inDataSize = 3;
- float inData[] = {1.0f, 2.0f, 3.0f};
- - TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inData[0]) size:inDataSize];
- + TFLFloatBuffer* inBuffer =
- + [[TFLFloatBuffer alloc] initWithData:&(inData[0]) size:inDataSize];
-
- NSInteger bufferSize = 5;
- - TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
- + TFLRingBuffer* ringBuffer =
- + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
-
- - XCTAssertTrue([ringBuffer loadBuffer:inBuffer offset:0 size:inDataSize error:nil]);
- + XCTAssertTrue([ringBuffer loadBuffer:inBuffer
- + offset:0
- + size:inDataSize
- + error:nil]);
-
- // State after load: [0.0, 0.0, 1.0, 2.0, 3.0]
-
- - TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer;
- + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer;
- XCTAssertNotNil(outBuffer);
- XCTAssertEqual(outBuffer.size, bufferSize);
-
- @@ -80,23 +92,32 @@ NS_ASSUME_NONNULL_BEGIN
- NSInteger initialDataSize = 4;
- float initialArray[] = {1.0f, 2.0f, 3.0f, 4.0f};
-
- - TFLFloatBuffer *initialBuffer =
- - [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize];
- + TFLFloatBuffer* initialBuffer =
- + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0])
- + size:initialDataSize];
-
- NSInteger bufferSize = 5;
- - TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
- + TFLRingBuffer* ringBuffer =
- + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
-
- - XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]);
- + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer
- + offset:0
- + size:initialDataSize
- + error:nil]);
-
- // State after load: [0.0, 1.0, 2.0, 3.0, 4.0]
-
- NSInteger inDataSize = 3;
- float inArray[] = {5, 6, 7};
- - TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:inDataSize];
- + TFLFloatBuffer* inBuffer =
- + [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:inDataSize];
-
- - XCTAssertTrue([ringBuffer loadBuffer:inBuffer offset:0 size:inDataSize error:nil]);
- + XCTAssertTrue([ringBuffer loadBuffer:inBuffer
- + offset:0
- + size:inDataSize
- + error:nil]);
-
- - TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer;
- + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer;
- XCTAssertNotNil(outBuffer);
- XCTAssertEqual(outBuffer.size, bufferSize);
-
- @@ -112,24 +133,33 @@ NS_ASSUME_NONNULL_BEGIN
- NSInteger initialDataSize = 5;
- float initialArray[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
-
- - TFLFloatBuffer *initialBuffer =
- - [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize];
- + TFLFloatBuffer* initialBuffer =
- + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0])
- + size:initialDataSize];
-
- NSInteger bufferSize = 5;
- - TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
- + TFLRingBuffer* ringBuffer =
- + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
-
- - XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]);
- + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer
- + offset:0
- + size:initialDataSize
- + error:nil]);
-
- // State after load: [1.0, 2.0, 3.0, 4.0, 5.0]
-
- NSInteger sourceDataSize = 6;
- float sourceArray[] = {6, 7, 8, 9, 10, 11};
- - TFLFloatBuffer *sourceBuffer =
- - [[TFLFloatBuffer alloc] initWithData:&(sourceArray[0]) size:sourceDataSize];
- + TFLFloatBuffer* sourceBuffer =
- + [[TFLFloatBuffer alloc] initWithData:&(sourceArray[0])
- + size:sourceDataSize];
-
- - XCTAssertTrue([ringBuffer loadBuffer:sourceBuffer offset:0 size:sourceDataSize error:nil]);
- + XCTAssertTrue([ringBuffer loadBuffer:sourceBuffer
- + offset:0
- + size:sourceDataSize
- + error:nil]);
-
- - TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer;
- + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer;
- XCTAssertNotNil(outBuffer);
- XCTAssertEqual(outBuffer.size, bufferSize);
-
- @@ -145,25 +175,34 @@ NS_ASSUME_NONNULL_BEGIN
- NSInteger initialDataSize = 5;
- float initialArray[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
-
- - TFLFloatBuffer *initialBuffer =
- - [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize];
- + TFLFloatBuffer* initialBuffer =
- + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0])
- + size:initialDataSize];
-
- NSInteger bufferSize = 5;
- - TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
- + TFLRingBuffer* ringBuffer =
- + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
-
- - XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]);
- + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer
- + offset:0
- + size:initialDataSize
- + error:nil]);
-
- // State after load: [1.0, 2.0, 3.0, 4.0, 5.0]
-
- NSInteger totalInSize = 8;
- float inArray[] = {6, 7, 8, 9, 10, 11, 12, 13};
- - TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize];
- + TFLFloatBuffer* inBuffer =
- + [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize];
-
- NSInteger offset = 2;
- NSInteger inDataSize = 6;
- - XCTAssertTrue([ringBuffer loadBuffer:inBuffer offset:offset size:inDataSize error:nil]);
- + XCTAssertTrue([ringBuffer loadBuffer:inBuffer
- + offset:offset
- + size:inDataSize
- + error:nil]);
-
- - TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer;
- + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer;
- XCTAssertNotNil(outBuffer);
- XCTAssertEqual(outBuffer.size, bufferSize);
-
- @@ -179,25 +218,34 @@ NS_ASSUME_NONNULL_BEGIN
- NSInteger initialDataSize = 2;
- float initialArray[] = {1.0f, 2.0f};
-
- - TFLFloatBuffer *initialBuffer =
- - [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize];
- + TFLFloatBuffer* initialBuffer =
- + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0])
- + size:initialDataSize];
-
- NSInteger bufferSize = 5;
- - TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
- + TFLRingBuffer* ringBuffer =
- + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
-
- - XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]);
- + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer
- + offset:0
- + size:initialDataSize
- + error:nil]);
-
- // State after load: [0.0, 0.0, 0.0, 1.0, 2.0]
-
- NSInteger totalInSize = 4;
- float inArray[] = {6.0f, 7.0f, 8.0f, 9.0f};
- - TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize];
- + TFLFloatBuffer* inBuffer =
- + [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize];
-
- NSInteger offset = 2;
- NSInteger inDataSize = 2;
- - XCTAssertTrue([ringBuffer loadBuffer:inBuffer offset:offset size:inDataSize error:nil]);
- + XCTAssertTrue([ringBuffer loadBuffer:inBuffer
- + offset:offset
- + size:inDataSize
- + error:nil]);
-
- - TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer;
- + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer;
- XCTAssertNotNil(outBuffer);
- XCTAssertEqual(outBuffer.size, bufferSize);
-
- @@ -213,26 +261,36 @@ NS_ASSUME_NONNULL_BEGIN
- NSInteger initialDataSize = 2;
- float initialArray[] = {1.0f, 2.0f};
-
- - TFLFloatBuffer *initialBuffer =
- - [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize];
- + TFLFloatBuffer* initialBuffer =
- + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0])
- + size:initialDataSize];
-
- NSInteger bufferSize = 5;
- - TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
- + TFLRingBuffer* ringBuffer =
- + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
-
- - XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]);
- + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer
- + offset:0
- + size:initialDataSize
- + error:nil]);
-
- NSInteger totalInSize = 4;
- float inArray[] = {6.0f, 7.0f, 8.0f, 9.0f};
- - TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize];
- + TFLFloatBuffer* inBuffer =
- + [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize];
-
- NSInteger offset = 2;
- NSInteger inDataSize = 3;
-
- - NSError *error = nil;
- - XCTAssertFalse([ringBuffer loadBuffer:inBuffer offset:offset size:inDataSize error:&error]);
- + NSError* error = nil;
- + XCTAssertFalse([ringBuffer loadBuffer:inBuffer
- + offset:offset
- + size:inDataSize
- + error:&error]);
-
- XCTAssertNotNil(error);
- - VerifyError(error, @"org.tensorflow.lite.tasks", TFLSupportErrorCodeInvalidArgumentError,
- + VerifyError(error, @"org.tensorflow.lite.tasks",
- + TFLSupportErrorCodeInvalidArgumentError,
- @"offset + size exceeds the maximum size of the source buffer.");
- }
-
- @@ -240,19 +298,24 @@ NS_ASSUME_NONNULL_BEGIN
- NSInteger initialDataSize = 2;
- float initialArray[] = {1.0f, 2.0f};
-
- - TFLFloatBuffer *initialBuffer =
- - [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize];
- + TFLFloatBuffer* initialBuffer =
- + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0])
- + size:initialDataSize];
-
- NSInteger bufferSize = 5;
- - TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
- + TFLRingBuffer* ringBuffer =
- + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize];
-
- - XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]);
- + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer
- + offset:0
- + size:initialDataSize
- + error:nil]);
-
- [ringBuffer clear];
-
- float expectedData[] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
-
- - TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer;
- + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer;
- XCTAssertNotNil(outBuffer);
- XCTAssertEqual(outBuffer.size, bufferSize);
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m
- index d03b6044bdd68..b1ed8cf1e2f6d 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m
- @@ -29,8 +29,9 @@ NS_ASSUME_NONNULL_BEGIN
- // Put setup code here. This method is called before the invocation of each test method in the
- // class.
- [super setUp];
- - self.modelPath = [[NSBundle bundleForClass:self.class] pathForResource:@"mobilenet_v2_1.0_224"
- - ofType:@"tflite"];
- + self.modelPath = [[NSBundle bundleForClass:self.class]
- + pathForResource:@"mobilenet_v2_1.0_224"
- + ofType:@"tflite"];
- XCTAssertNotNil(self.modelPath);
- }
-
- @@ -42,8 +43,9 @@ NS_ASSUME_NONNULL_BEGIN
- [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:nil];
- XCTAssertNotNil(imageClassifier);
-
- - GMLImage *gmlImage =
- - [GMLImage imageFromBundleWithClass:self.class fileName:@"burger" ofType:@"jpg"];
- + GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:self.class
- + fileName:@"burger"
- + ofType:@"jpg"];
- XCTAssertNotNil(gmlImage);
-
- TFLClassificationResult *classificationResults = [imageClassifier classifyWithGMLImage:gmlImage
- @@ -67,14 +69,16 @@ NS_ASSUME_NONNULL_BEGIN
- [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:nil];
- XCTAssertNotNil(imageClassifier);
-
- - GMLImage *gmlImage =
- - [GMLImage imageFromBundleWithClass:self.class fileName:@"burger" ofType:@"jpg"];
- + GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:self.class
- + fileName:@"burger"
- + ofType:@"jpg"];
- XCTAssertNotNil(gmlImage);
-
- TFLClassificationResult *classificationResults = [imageClassifier classifyWithGMLImage:gmlImage
- error:nil];
- XCTAssertTrue(classificationResults.classifications.count > 0);
- - XCTAssertLessThanOrEqual(classificationResults.classifications[0].categories.count, maxResults);
- + XCTAssertLessThanOrEqual(
- + classificationResults.classifications[0].categories.count, maxResults);
-
- TFLCategory *category = classificationResults.classifications[0].categories[0];
- XCTAssertTrue([category.label isEqual:@"cheeseburger"]);
- @@ -92,8 +96,9 @@ NS_ASSUME_NONNULL_BEGIN
- [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:nil];
- XCTAssertNotNil(imageClassifier);
-
- - GMLImage *gmlImage =
- - [GMLImage imageFromBundleWithClass:self.class fileName:@"multi_objects" ofType:@"jpg"];
- + GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:self.class
- + fileName:@"multi_objects"
- + ofType:@"jpg"];
- XCTAssertNotNil(gmlImage);
-
- CGRect roi = CGRectMake(406, 110, 148, 153);
- @@ -117,8 +122,9 @@ NS_ASSUME_NONNULL_BEGIN
- [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:nil];
- XCTAssertNotNil(imageClassifier);
-
- - GMLImage *gmlImage =
- - [GMLImage imageFromBundleWithClass:self.class fileName:@"sparrow" ofType:@"png"];
- + GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:self.class
- + fileName:@"sparrow"
- + ofType:@"png"];
- XCTAssertNotNil(gmlImage);
-
- TFLClassificationResult *classificationResults = [imageClassifier classifyWithGMLImage:gmlImage
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m
- index c2977475f6d4f..f483a516b9bc6 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m
- @@ -18,10 +18,11 @@
- #import "tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h"
- #import "tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h"
-
- -#define VerifyColoredLabel(coloredLabel, expectedR, expectedG, expectedB, expectedLabel) \
- - XCTAssertEqual(coloredLabel.r, expectedR); \
- - XCTAssertEqual(coloredLabel.g, expectedG); \
- - XCTAssertEqual(coloredLabel.b, expectedB); \
- +#define VerifyColoredLabel(coloredLabel, expectedR, expectedG, expectedB, \
- + expectedLabel) \
- + XCTAssertEqual(coloredLabel.r, expectedR); \
- + XCTAssertEqual(coloredLabel.g, expectedG); \
- + XCTAssertEqual(coloredLabel.b, expectedB); \
- XCTAssertEqualObjects(coloredLabel.label, expectedLabel)
-
- // The maximum fraction of pixels in the candidate mask that can have a
- @@ -40,22 +41,24 @@ NSInteger const deepLabV3SegmentationHeight = 257;
-
- @interface TFLImageSegmenterTests : XCTestCase
-
- -@property(nonatomic, nullable) NSString *modelPath;
- +@property(nonatomic, nullable) NSString* modelPath;
-
- @end
-
- @implementation TFLImageSegmenterTests
-
- - (void)setUp {
- - // Put setup code here. This method is called before the invocation of each test method in the
- - // class.
- + // Put setup code here. This method is called before the invocation of each
- + // test method in the class.
- [super setUp];
- - self.modelPath = [[NSBundle bundleForClass:self.class] pathForResource:@"deeplabv3"
- - ofType:@"tflite"];
- + self.modelPath =
- + [[NSBundle bundleForClass:self.class] pathForResource:@"deeplabv3"
- + ofType:@"tflite"];
- XCTAssertNotNil(self.modelPath);
- }
-
- -- (void)compareWithDeepLabV3PartialColoredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels {
- +- (void)compareWithDeepLabV3PartialColoredLabels:
- + (NSArray<TFLColoredLabel*>*)coloredLabels {
- VerifyColoredLabel(coloredLabels[0],
- 0, // expectedR
- 0, // expectedG
- @@ -204,58 +207,67 @@ NSInteger const deepLabV3SegmentationHeight = 257;
- }
-
- - (void)testSuccessfulImageSegmentationWithCategoryMask {
- - TFLImageSegmenterOptions *imageSegmenterOptions =
- + TFLImageSegmenterOptions* imageSegmenterOptions =
- [[TFLImageSegmenterOptions alloc] initWithModelPath:self.modelPath];
-
- - TFLImageSegmenter *imageSegmenter =
- - [TFLImageSegmenter imageSegmenterWithOptions:imageSegmenterOptions error:nil];
- + TFLImageSegmenter* imageSegmenter =
- + [TFLImageSegmenter imageSegmenterWithOptions:imageSegmenterOptions
- + error:nil];
- XCTAssertNotNil(imageSegmenter);
-
- - GMLImage *gmlImage = [GMLImage imageFromBundleWithClass:self.class
- - fileName:@"segmentation_input_rotation0"
- - ofType:@"jpg"];
- + GMLImage* gmlImage =
- + [GMLImage imageFromBundleWithClass:self.class
- + fileName:@"segmentation_input_rotation0"
- + ofType:@"jpg"];
- XCTAssertNotNil(gmlImage);
-
- - TFLSegmentationResult *segmentationResult = [imageSegmenter segmentWithGMLImage:gmlImage
- - error:nil];
- + TFLSegmentationResult* segmentationResult =
- + [imageSegmenter segmentWithGMLImage:gmlImage error:nil];
-
- XCTAssertNotNil(segmentationResult);
- XCTAssertEqual(segmentationResult.segmentations.count, 1);
-
- XCTAssertNotNil(segmentationResult.segmentations[0].coloredLabels);
- - [self compareWithDeepLabV3PartialColoredLabels:segmentationResult.segmentations[0].coloredLabels];
- + [self compareWithDeepLabV3PartialColoredLabels:segmentationResult
- + .segmentations[0]
- + .coloredLabels];
-
- XCTAssertNotNil(segmentationResult.segmentations[0].categoryMask);
- XCTAssertTrue(segmentationResult.segmentations[0].categoryMask.mask != nil);
-
- - GMLImage *goldenImage = [GMLImage imageFromBundleWithClass:self.class
- - fileName:@"segmentation_golden_rotation0"
- - ofType:@"png"];
- + GMLImage* goldenImage =
- + [GMLImage imageFromBundleWithClass:self.class
- + fileName:@"segmentation_golden_rotation0"
- + ofType:@"png"];
-
- XCTAssertNotNil(goldenImage);
- CVPixelBufferRef pixelBuffer = [goldenImage grayScalePixelBuffer];
-
- CVPixelBufferLockBaseAddress(pixelBuffer, kCVPixelBufferLock_ReadOnly);
-
- - UInt8 *pixelBufferBaseAddress = (UInt8 *)CVPixelBufferGetBaseAddress(pixelBuffer);
- + UInt8* pixelBufferBaseAddress =
- + (UInt8*)CVPixelBufferGetBaseAddress(pixelBuffer);
-
- XCTAssertEqual(deepLabV3SegmentationWidth,
- segmentationResult.segmentations[0].categoryMask.width);
- XCTAssertEqual(deepLabV3SegmentationHeight,
- segmentationResult.segmentations[0].categoryMask.height);
-
- - NSInteger numPixels = deepLabV3SegmentationWidth * deepLabV3SegmentationHeight;
- + NSInteger numPixels =
- + deepLabV3SegmentationWidth * deepLabV3SegmentationHeight;
-
- float inconsistentPixels = 0;
-
- for (int i = 0; i < numPixels; i++)
- - if (segmentationResult.segmentations[0].categoryMask.mask[i] * kGoldenMaskMagnificationFactor !=
- + if (segmentationResult.segmentations[0].categoryMask.mask[i] *
- + kGoldenMaskMagnificationFactor !=
- pixelBufferBaseAddress[i])
- inconsistentPixels += 1;
-
- CVPixelBufferUnlockBaseAddress(pixelBuffer, kCVPixelBufferLock_ReadOnly);
-
- - XCTAssertLessThan(inconsistentPixels / (float)numPixels, kGoldenMaskTolerance);
- + XCTAssertLessThan(inconsistentPixels / (float)numPixels,
- + kGoldenMaskTolerance);
- }
-
- @end
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.m
- index f6820f335e18b..f7091a5995b02 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.m
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.m
- @@ -18,16 +18,22 @@
- #import "tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h"
- #import "tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h"
-
- -#define VerifyDetection(detection, expectedBoundingBox, expectedFirstScore, expectedFirstLabel) \
- - XCTAssertGreaterThan(detection.categories.count, 0); \
- - NSLog(@"Detected %f", detection.categories[0].score); \
- - NSLog(@"Expected %f", expectedFirstScore); \
- - XCTAssertEqual(detection.boundingBox.origin.x, expectedBoundingBox.origin.x); \
- - XCTAssertEqual(detection.boundingBox.origin.y, expectedBoundingBox.origin.y); \
- - XCTAssertEqual(detection.boundingBox.size.width, expectedBoundingBox.size.width); \
- - XCTAssertEqual(detection.boundingBox.size.height, expectedBoundingBox.size.height); \
- - XCTAssertEqualObjects(detection.categories[0].label, expectedFirstLabel); \
- - XCTAssertEqualWithAccuracy(detection.categories[0].score, expectedFirstScore, 0.001)
- +#define VerifyDetection(detection, expectedBoundingBox, expectedFirstScore, \
- + expectedFirstLabel) \
- + XCTAssertGreaterThan(detection.categories.count, 0); \
- + NSLog(@"Detected %f", detection.categories[0].score); \
- + NSLog(@"Expected %f", expectedFirstScore); \
- + XCTAssertEqual(detection.boundingBox.origin.x, \
- + expectedBoundingBox.origin.x); \
- + XCTAssertEqual(detection.boundingBox.origin.y, \
- + expectedBoundingBox.origin.y); \
- + XCTAssertEqual(detection.boundingBox.size.width, \
- + expectedBoundingBox.size.width); \
- + XCTAssertEqual(detection.boundingBox.size.height, \
- + expectedBoundingBox.size.height); \
- + XCTAssertEqualObjects(detection.categories[0].label, expectedFirstLabel); \
- + XCTAssertEqualWithAccuracy(detection.categories[0].score, \
- + expectedFirstScore, 0.001)
-
- @interface TFLObjectDetectorTests : XCTestCase
- @property(nonatomic, nullable) NSString *modelPath;
- @@ -77,8 +83,9 @@
- [TFLObjectDetector objectDetectorWithOptions:objectDetectorOptions error:nil];
- XCTAssertNotNil(objectDetector);
-
- - GMLImage *gmlImage =
- - [GMLImage imageFromBundleWithClass:self.class fileName:@"cats_and_dogs" ofType:@"jpg"];
- + GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:self.class
- + fileName:@"cats_and_dogs"
- + ofType:@"jpg"];
- XCTAssertNotNil(gmlImage);
-
- TFLDetectionResult *detectionResults = [objectDetector detectWithGMLImage:gmlImage error:nil];
- @@ -95,8 +102,9 @@
- [TFLObjectDetector objectDetectorWithOptions:objectDetectorOptions error:nil];
- XCTAssertNotNil(objectDetector);
-
- - GMLImage *gmlImage =
- - [GMLImage imageFromBundleWithClass:self.class fileName:@"cats_and_dogs" ofType:@"jpg"];
- + GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:self.class
- + fileName:@"cats_and_dogs"
- + ofType:@"jpg"];
- XCTAssertNotNil(gmlImage);
-
- TFLDetectionResult *detectionResult = [objectDetector detectWithGMLImage:gmlImage error:nil];
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h
- index ed679c22a467b..c10c82afc1913 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h
- @@ -28,11 +28,13 @@ NS_ASSUME_NONNULL_BEGIN
- /**
- * Initializes the tokenizer with the path to wordpiece vocabulary file.
- */
- -- (instancetype)initWithVocabPath:(NSString *)vocabPath NS_DESIGNATED_INITIALIZER;
- +- (instancetype)initWithVocabPath:(NSString*)vocabPath
- + NS_DESIGNATED_INITIALIZER;
-
- /**
- * Initializes the tokenizer with a list of tokens.
- */
- -- (instancetype)initWithVocab:(NSArray<NSString *> *)vocab NS_DESIGNATED_INITIALIZER;
- +- (instancetype)initWithVocab:(NSArray<NSString*>*)vocab
- + NS_DESIGNATED_INITIALIZER;
- @end
- NS_ASSUME_NONNULL_END
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h
- index f556dc642d736..be4010abd8e6f 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h
- @@ -28,6 +28,6 @@ NS_ASSUME_NONNULL_BEGIN
- /**
- * Initializes the tokenizer with the path to sentencepiece model file.
- */
- -- (instancetype)initWithModelPath:(NSString *)modelPath;
- +- (instancetype)initWithModelPath:(NSString*)modelPath;
- @end
- NS_ASSUME_NONNULL_END
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h
- index ee0972f8aba30..bd832060b6e80 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h
- @@ -26,7 +26,7 @@ NS_ASSUME_NONNULL_BEGIN
- *
- * @return A list of tokens.
- */
- -- (NSArray<NSString *> *)tokensFromInput:(NSString *)input;
- +- (NSArray<NSString*>*)tokensFromInput:(NSString*)input;
-
- /*
- * Convert a list of tokens back to their coressponding IDs.
- @@ -34,6 +34,6 @@ NS_ASSUME_NONNULL_BEGIN
- *
- * @return A list of ids.
- */
- -- (NSArray<NSNumber *> *)idsFromTokens:(NSArray<NSString *> *)tokens;
- +- (NSArray<NSNumber*>*)idsFromTokens:(NSArray<NSString*>*)tokens;
- @end
- NS_ASSUME_NONNULL_END
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h
- index 574b555301616..14e2906675b71 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h
- @@ -18,21 +18,24 @@ limitations under the License.
- using ::tflite::support::text::tokenizer::Tokenizer;
-
- /**
- - * Invokes the cpp tokenizer's tokenize function and converts input/output to objc.
- + * Invokes the cpp tokenizer's tokenize function and converts input/output to
- + * objc.
- *
- * @param tokenizer The cpp tokenizer pointer.
- * @param input The input string to be tokenized.
- *
- * @return A list of tokens.
- */
- -NSArray<NSString *> *Tokenize(Tokenizer *tokenizer, NSString *input);
- +NSArray<NSString*>* Tokenize(Tokenizer* tokenizer, NSString* input);
-
- /**
- - * Invokes the cpp tokenizer's convertTokensToIds function and converts input/output to objc.
- + * Invokes the cpp tokenizer's convertTokensToIds function and converts
- + * input/output to objc.
- *
- * @param tokenizer The cpp tokenizer pointer.
- * @param input The tokens to be converted.
- *
- * @return A list of ids.
- */
- -NSArray<NSNumber *> *ConvertTokensToIds(Tokenizer *tokenizer, NSArray<NSString *> *tokens);
- +NSArray<NSNumber*>* ConvertTokensToIds(Tokenizer* tokenizer,
- + NSArray<NSString*>* tokens);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm b/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm
- index 6e9cf23802427..2a11bb6730474 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm
- @@ -14,10 +14,13 @@ limitations under the License.
- ==============================================================================*/
- #import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h"
-
- -std::string MakeString(NSString* str) { return std::string([str UTF8String]); }
- +std::string MakeString(NSString* str) {
- + return std::string([str UTF8String]);
- +}
-
- NSString* MakeNSString(const std::string& str) {
- - return [[NSString alloc] initWithBytes:const_cast<void*>(static_cast<const void*>(str.data()))
- - length:str.length()
- - encoding:NSUTF8StringEncoding];
- + return [[NSString alloc]
- + initWithBytes:const_cast<void*>(static_cast<const void*>(str.data()))
- + length:str.length()
- + encoding:NSUTF8StringEncoding];
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java
- index 2b59c675b0316..6f2f2d437fb4a 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java
- @@ -15,19 +15,24 @@ limitations under the License.
-
- package org.tensorflow.lite.support.audio;
-
- -import static java.lang.System.arraycopy;
- import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument;
-
- +import static java.lang.System.arraycopy;
- +
- import android.media.AudioFormat;
- import android.media.AudioRecord;
- import android.os.Build;
- +
- import androidx.annotation.RequiresApi;
- +
- import com.google.auto.value.AutoValue;
- +
- +import org.tensorflow.lite.DataType;
- +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- +
- import java.nio.ByteBuffer;
- import java.nio.ByteOrder;
- import java.nio.FloatBuffer;
- -import org.tensorflow.lite.DataType;
- -import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- /**
- * Defines a ring buffer and some utility functions to prepare the input audio samples.
- @@ -60,285 +65,282 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- * </pre>
- */
- public class TensorAudio {
- + private static final String TAG = TensorAudio.class.getSimpleName();
- + private final FloatRingBuffer buffer;
- + private final TensorAudioFormat format;
-
- - private static final String TAG = TensorAudio.class.getSimpleName();
- - private final FloatRingBuffer buffer;
- - private final TensorAudioFormat format;
- -
- - /**
- - * Creates a {@link android.media.AudioRecord} instance with a ring buffer whose size is {@code
- - * sampleCounts} * {@code format.getChannels()}.
- - *
- - * @param format the expected {@link TensorAudioFormat} of audio data loaded into this class.
- - * @param sampleCounts the number of samples to be fed into the model
- - */
- - public static TensorAudio create(TensorAudioFormat format, int sampleCounts) {
- - return new TensorAudio(format, sampleCounts);
- - }
- -
- - /**
- - * Creates a {@link TensorAudio} instance with a ring buffer whose size is {@code sampleCounts} *
- - * {@code format.getChannelCount()}.
- - *
- - * @param format the {@link android.media.AudioFormat} required by the TFLite model. It defines
- - * the number of channels and sample rate.
- - * @param sampleCounts the number of samples to be fed into the model
- - */
- - public static TensorAudio create(AudioFormat format, int sampleCounts) {
- - return new TensorAudio(TensorAudioFormat.create(format), sampleCounts);
- - }
- -
- - /**
- - * Wraps a few constants describing the format of the incoming audio samples, namely number of
- - * channels and the sample rate. By default, channels is set to 1.
- - */
- - @AutoValue
- - public abstract static class TensorAudioFormat {
- - private static final int DEFAULT_CHANNELS = 1;
- -
- - /** Creates a {@link TensorAudioFormat} instance from Android AudioFormat class. */
- - @RequiresApi(Build.VERSION_CODES.M)
- - public static TensorAudioFormat create(AudioFormat format) {
- - return TensorAudioFormat.builder()
- - .setChannels(format.getChannelCount())
- - .setSampleRate(format.getSampleRate())
- - .build();
- + /**
- + * Creates a {@link android.media.AudioRecord} instance with a ring buffer whose size is {@code
- + * sampleCounts} * {@code format.getChannels()}.
- + *
- + * @param format the expected {@link TensorAudioFormat} of audio data loaded into this class.
- + * @param sampleCounts the number of samples to be fed into the model
- + */
- + public static TensorAudio create(TensorAudioFormat format, int sampleCounts) {
- + return new TensorAudio(format, sampleCounts);
- }
-
- - public abstract int getChannels();
- -
- - public abstract int getSampleRate();
- -
- - public static Builder builder() {
- - return new AutoValue_TensorAudio_TensorAudioFormat.Builder().setChannels(DEFAULT_CHANNELS);
- + /**
- + * Creates a {@link TensorAudio} instance with a ring buffer whose size is {@code sampleCounts}
- + * *
- + * {@code format.getChannelCount()}.
- + *
- + * @param format the {@link android.media.AudioFormat} required by the TFLite model. It defines
- + * the number of channels and sample rate.
- + * @param sampleCounts the number of samples to be fed into the model
- + */
- + public static TensorAudio create(AudioFormat format, int sampleCounts) {
- + return new TensorAudio(TensorAudioFormat.create(format), sampleCounts);
- }
-
- - /** Builder for {@link TensorAudioFormat} */
- - @AutoValue.Builder
- - public abstract static class Builder {
- -
- - /* By default, it's set to have 1 channel. */
- - public abstract Builder setChannels(int value);
- -
- - public abstract Builder setSampleRate(int value);
- -
- - abstract TensorAudioFormat autoBuild();
- -
- - public TensorAudioFormat build() {
- - TensorAudioFormat format = autoBuild();
- - checkArgument(format.getChannels() > 0, "Number of channels should be greater than 0");
- - checkArgument(format.getSampleRate() > 0, "Sample rate should be greater than 0");
- - return format;
- - }
- - }
- - }
- -
- - /**
- - * Stores the input audio samples {@code src} in the ring buffer.
- - *
- - * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
- - * multi-channel input, the array is interleaved.
- - */
- - public void load(float[] src) {
- - load(src, 0, src.length);
- - }
- -
- - /**
- - * Stores the input audio samples {@code src} in the ring buffer.
- - *
- - * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
- - * multi-channel input, the array is interleaved.
- - * @param offsetInFloat starting position in the {@code src} array
- - * @param sizeInFloat the number of float values to be copied
- - * @throws IllegalArgumentException for incompatible audio format or incorrect input size
- - */
- - public void load(float[] src, int offsetInFloat, int sizeInFloat) {
- - checkArgument(
- - sizeInFloat % format.getChannels() == 0,
- - String.format(
- - "Size (%d) needs to be a multiplier of the number of channels (%d)",
- - sizeInFloat, format.getChannels()));
- - buffer.load(src, offsetInFloat, sizeInFloat);
- - }
- -
- - /**
- - * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ring
- - * buffer.
- - *
- - * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
- - * multi-channel input, the array is interleaved.
- - */
- - public void load(short[] src) {
- - load(src, 0, src.length);
- - }
- -
- - /**
- - * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ring
- - * buffer.
- - *
- - * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
- - * multi-channel input, the array is interleaved.
- - * @param offsetInShort starting position in the src array
- - * @param sizeInShort the number of short values to be copied
- - * @throws IllegalArgumentException if the source array can't be copied
- - */
- - public void load(short[] src, int offsetInShort, int sizeInShort) {
- - checkArgument(
- - offsetInShort + sizeInShort <= src.length,
- - String.format(
- - "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
- - offsetInShort, sizeInShort, src.length));
- - float[] floatData = new float[sizeInShort];
- - for (int i = 0; i < sizeInShort; i++) {
- - // Convert the data to PCM Float encoding i.e. values between -1 and 1
- - floatData[i] = src[i + offsetInShort] * 1.f / Short.MAX_VALUE;
- - }
- - load(floatData);
- - }
- -
- - /**
- - * Loads latest data from the {@link android.media.AudioRecord} in a non-blocking way. Only
- - * supporting ENCODING_PCM_16BIT and ENCODING_PCM_FLOAT.
- - *
- - * @param record an instance of {@link android.media.AudioRecord}
- - * @return number of captured audio values whose size is {@code channelCount * sampleCount}. If
- - * there was no new data in the AudioRecord or an error occurred, this method will return 0.
- - * @throws IllegalArgumentException for unsupported audio encoding format
- - * @throws IllegalStateException if reading from AudioRecord failed
- - */
- - @RequiresApi(Build.VERSION_CODES.M)
- - public int load(AudioRecord record) {
- - checkArgument(
- - this.format.equals(TensorAudioFormat.create(record.getFormat())),
- - "Incompatible audio format.");
- - int loadedValues = 0;
- - if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_FLOAT) {
- - float[] newData = new float[record.getChannelCount() * record.getBufferSizeInFrames()];
- - loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
- - if (loadedValues > 0) {
- - load(newData, 0, loadedValues);
- - return loadedValues;
- - }
- - } else if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_16BIT) {
- - short[] newData = new short[record.getChannelCount() * record.getBufferSizeInFrames()];
- - loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
- - if (loadedValues > 0) {
- - load(newData, 0, loadedValues);
- - return loadedValues;
- - }
- - } else {
- - throw new IllegalArgumentException(
- - "Unsupported encoding. Requires ENCODING_PCM_16BIT or ENCODING_PCM_FLOAT.");
- + /**
- + * Wraps a few constants describing the format of the incoming audio samples, namely number of
- + * channels and the sample rate. By default, channels is set to 1.
- + */
- + @AutoValue
- + public abstract static class TensorAudioFormat {
- + private static final int DEFAULT_CHANNELS = 1;
- +
- + /** Creates a {@link TensorAudioFormat} instance from Android AudioFormat class. */
- + @RequiresApi(Build.VERSION_CODES.M)
- + public static TensorAudioFormat create(AudioFormat format) {
- + return TensorAudioFormat.builder()
- + .setChannels(format.getChannelCount())
- + .setSampleRate(format.getSampleRate())
- + .build();
- + }
- +
- + public abstract int getChannels();
- +
- + public abstract int getSampleRate();
- +
- + public static Builder builder() {
- + return new AutoValue_TensorAudio_TensorAudioFormat.Builder().setChannels(
- + DEFAULT_CHANNELS);
- + }
- +
- + /** Builder for {@link TensorAudioFormat} */
- + @AutoValue.Builder
- + public abstract static class Builder {
- + /* By default, it's set to have 1 channel. */
- + public abstract Builder setChannels(int value);
- +
- + public abstract Builder setSampleRate(int value);
- +
- + abstract TensorAudioFormat autoBuild();
- +
- + public TensorAudioFormat build() {
- + TensorAudioFormat format = autoBuild();
- + checkArgument(
- + format.getChannels() > 0, "Number of channels should be greater than 0");
- + checkArgument(format.getSampleRate() > 0, "Sample rate should be greater than 0");
- + return format;
- + }
- + }
- }
-
- - switch (loadedValues) {
- - case AudioRecord.ERROR_INVALID_OPERATION:
- - throw new IllegalStateException("AudioRecord.ERROR_INVALID_OPERATION");
- -
- - case AudioRecord.ERROR_BAD_VALUE:
- - throw new IllegalStateException("AudioRecord.ERROR_BAD_VALUE");
- -
- - case AudioRecord.ERROR_DEAD_OBJECT:
- - throw new IllegalStateException("AudioRecord.ERROR_DEAD_OBJECT");
- + /**
- + * Stores the input audio samples {@code src} in the ring buffer.
- + *
- + * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
- + * multi-channel input, the array is interleaved.
- + */
- + public void load(float[] src) {
- + load(src, 0, src.length);
- + }
-
- - case AudioRecord.ERROR:
- - throw new IllegalStateException("AudioRecord.ERROR");
- + /**
- + * Stores the input audio samples {@code src} in the ring buffer.
- + *
- + * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
- + * multi-channel input, the array is interleaved.
- + * @param offsetInFloat starting position in the {@code src} array
- + * @param sizeInFloat the number of float values to be copied
- + * @throws IllegalArgumentException for incompatible audio format or incorrect input size
- + */
- + public void load(float[] src, int offsetInFloat, int sizeInFloat) {
- + checkArgument(sizeInFloat % format.getChannels() == 0,
- + String.format("Size (%d) needs to be a multiplier of the number of channels (%d)",
- + sizeInFloat, format.getChannels()));
- + buffer.load(src, offsetInFloat, sizeInFloat);
- + }
-
- - default:
- - return 0;
- + /**
- + * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the
- + * ring buffer.
- + *
- + * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
- + * multi-channel input, the array is interleaved.
- + */
- + public void load(short[] src) {
- + load(src, 0, src.length);
- }
- - }
- -
- - /**
- - * Returns a float {@link TensorBuffer} holding all the available audio samples in {@link
- - * android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1].
- - */
- - public TensorBuffer getTensorBuffer() {
- - ByteBuffer byteBuffer = buffer.getBuffer();
- - TensorBuffer tensorBuffer =
- - TensorBuffer.createFixedSize(
- - new int[] {
- - /* batch= */ 1, /* modelInputLengthInFloat= */ byteBuffer.asFloatBuffer().limit()
- - },
- - DataType.FLOAT32);
- - tensorBuffer.loadBuffer(byteBuffer);
- - return tensorBuffer;
- - }
- -
- - /* Returns the {@link TensorAudioFormat} associated with the tensor. */
- - public TensorAudioFormat getFormat() {
- - return format;
- - }
- -
- - private TensorAudio(TensorAudioFormat format, int sampleCounts) {
- - this.format = format;
- - this.buffer = new FloatRingBuffer(sampleCounts * format.getChannels());
- - }
- -
- - /** Actual implementation of the ring buffer. */
- - private static class FloatRingBuffer {
- -
- - private final float[] buffer;
- - private int nextIndex = 0;
- -
- - public FloatRingBuffer(int flatSize) {
- - buffer = new float[flatSize];
- +
- + /**
- + * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the
- + * ring buffer.
- + *
- + * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
- + * multi-channel input, the array is interleaved.
- + * @param offsetInShort starting position in the src array
- + * @param sizeInShort the number of short values to be copied
- + * @throws IllegalArgumentException if the source array can't be copied
- + */
- + public void load(short[] src, int offsetInShort, int sizeInShort) {
- + checkArgument(offsetInShort + sizeInShort <= src.length,
- + String.format(
- + "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
- + offsetInShort, sizeInShort, src.length));
- + float[] floatData = new float[sizeInShort];
- + for (int i = 0; i < sizeInShort; i++) {
- + // Convert the data to PCM Float encoding i.e. values between -1 and 1
- + floatData[i] = src[i + offsetInShort] * 1.f / Short.MAX_VALUE;
- + }
- + load(floatData);
- }
-
- /**
- - * Loads the entire float array to the ring buffer. If the float array is longer than ring
- - * buffer's capacity, samples with lower indices in the array will be ignored.
- + * Loads latest data from the {@link android.media.AudioRecord} in a non-blocking way. Only
- + * supporting ENCODING_PCM_16BIT and ENCODING_PCM_FLOAT.
- + *
- + * @param record an instance of {@link android.media.AudioRecord}
- + * @return number of captured audio values whose size is {@code channelCount * sampleCount}. If
- + * there was no new data in the AudioRecord or an error occurred, this method will return 0.
- + * @throws IllegalArgumentException for unsupported audio encoding format
- + * @throws IllegalStateException if reading from AudioRecord failed
- */
- - public void load(float[] newData) {
- - load(newData, 0, newData.length);
- + @RequiresApi(Build.VERSION_CODES.M)
- + public int load(AudioRecord record) {
- + checkArgument(this.format.equals(TensorAudioFormat.create(record.getFormat())),
- + "Incompatible audio format.");
- + int loadedValues = 0;
- + if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_FLOAT) {
- + float[] newData = new float[record.getChannelCount() * record.getBufferSizeInFrames()];
- + loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
- + if (loadedValues > 0) {
- + load(newData, 0, loadedValues);
- + return loadedValues;
- + }
- + } else if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_16BIT) {
- + short[] newData = new short[record.getChannelCount() * record.getBufferSizeInFrames()];
- + loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
- + if (loadedValues > 0) {
- + load(newData, 0, loadedValues);
- + return loadedValues;
- + }
- + } else {
- + throw new IllegalArgumentException(
- + "Unsupported encoding. Requires ENCODING_PCM_16BIT or ENCODING_PCM_FLOAT.");
- + }
- +
- + switch (loadedValues) {
- + case AudioRecord.ERROR_INVALID_OPERATION:
- + throw new IllegalStateException("AudioRecord.ERROR_INVALID_OPERATION");
- +
- + case AudioRecord.ERROR_BAD_VALUE:
- + throw new IllegalStateException("AudioRecord.ERROR_BAD_VALUE");
- +
- + case AudioRecord.ERROR_DEAD_OBJECT:
- + throw new IllegalStateException("AudioRecord.ERROR_DEAD_OBJECT");
- +
- + case AudioRecord.ERROR:
- + throw new IllegalStateException("AudioRecord.ERROR");
- +
- + default:
- + return 0;
- + }
- }
-
- /**
- - * Loads a slice of the float array to the ring buffer. If the float array is longer than ring
- - * buffer's capacity, samples with lower indices in the array will be ignored.
- + * Returns a float {@link TensorBuffer} holding all the available audio samples in {@link
- + * android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1].
- */
- - public void load(float[] newData, int offset, int size) {
- - checkArgument(
- - offset + size <= newData.length,
- - String.format(
- - "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
- - offset, size, newData.length));
- - // If buffer can't hold all the data, only keep the most recent data of size buffer.length
- - if (size > buffer.length) {
- - offset += (size - buffer.length);
- - size = buffer.length;
- - }
- - if (nextIndex + size < buffer.length) {
- - // No need to wrap nextIndex, just copy newData[offset:offset + size]
- - // to buffer[nextIndex:nextIndex+size]
- - arraycopy(newData, offset, buffer, nextIndex, size);
- - } else {
- - // Need to wrap nextIndex, perform copy in two chunks.
- - int firstChunkSize = buffer.length - nextIndex;
- - // First copy newData[offset:offset+firstChunkSize] to buffer[nextIndex:buffer.length]
- - arraycopy(newData, offset, buffer, nextIndex, firstChunkSize);
- - // Then copy newData[offset+firstChunkSize:offset+size] to buffer[0:size-firstChunkSize]
- - arraycopy(newData, offset + firstChunkSize, buffer, 0, size - firstChunkSize);
- - }
- -
- - nextIndex = (nextIndex + size) % buffer.length;
- + public TensorBuffer getTensorBuffer() {
- + ByteBuffer byteBuffer = buffer.getBuffer();
- + TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(
- + new int[] {/* batch= */ 1,
- + /* modelInputLengthInFloat= */ byteBuffer.asFloatBuffer().limit()},
- + DataType.FLOAT32);
- + tensorBuffer.loadBuffer(byteBuffer);
- + return tensorBuffer;
- + }
- +
- + /* Returns the {@link TensorAudioFormat} associated with the tensor. */
- + public TensorAudioFormat getFormat() {
- + return format;
- }
-
- - public ByteBuffer getBuffer() {
- - // Create non-direct buffers. On Pixel 4, creating direct buffer costs around 0.1 ms, which
- - // can be 5x ~ 10x longer compared to non-direct buffer backed by arrays (around 0.01ms), so
- - // generally we don't create direct buffer for every invocation.
- - ByteBuffer byteBuffer = ByteBuffer.allocate(DataType.FLOAT32.byteSize() * buffer.length);
- - byteBuffer.order(ByteOrder.nativeOrder());
- - FloatBuffer result = byteBuffer.asFloatBuffer();
- - result.put(buffer, nextIndex, buffer.length - nextIndex);
- - result.put(buffer, 0, nextIndex);
- - byteBuffer.rewind();
- - return byteBuffer;
- + private TensorAudio(TensorAudioFormat format, int sampleCounts) {
- + this.format = format;
- + this.buffer = new FloatRingBuffer(sampleCounts * format.getChannels());
- }
-
- - public int getCapacity() {
- - return buffer.length;
- + /** Actual implementation of the ring buffer. */
- + private static class FloatRingBuffer {
- + private final float[] buffer;
- + private int nextIndex = 0;
- +
- + public FloatRingBuffer(int flatSize) {
- + buffer = new float[flatSize];
- + }
- +
- + /**
- + * Loads the entire float array to the ring buffer. If the float array is longer than ring
- + * buffer's capacity, samples with lower indices in the array will be ignored.
- + */
- + public void load(float[] newData) {
- + load(newData, 0, newData.length);
- + }
- +
- + /**
- + * Loads a slice of the float array to the ring buffer. If the float array is longer than
- + * ring buffer's capacity, samples with lower indices in the array will be ignored.
- + */
- + public void load(float[] newData, int offset, int size) {
- + checkArgument(offset + size <= newData.length,
- + String.format(
- + "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
- + offset, size, newData.length));
- + // If buffer can't hold all the data, only keep the most recent data of size
- + // buffer.length
- + if (size > buffer.length) {
- + offset += (size - buffer.length);
- + size = buffer.length;
- + }
- + if (nextIndex + size < buffer.length) {
- + // No need to wrap nextIndex, just copy newData[offset:offset + size]
- + // to buffer[nextIndex:nextIndex+size]
- + arraycopy(newData, offset, buffer, nextIndex, size);
- + } else {
- + // Need to wrap nextIndex, perform copy in two chunks.
- + int firstChunkSize = buffer.length - nextIndex;
- + // First copy newData[offset:offset+firstChunkSize] to
- + // buffer[nextIndex:buffer.length]
- + arraycopy(newData, offset, buffer, nextIndex, firstChunkSize);
- + // Then copy newData[offset+firstChunkSize:offset+size] to
- + // buffer[0:size-firstChunkSize]
- + arraycopy(newData, offset + firstChunkSize, buffer, 0, size - firstChunkSize);
- + }
- +
- + nextIndex = (nextIndex + size) % buffer.length;
- + }
- +
- + public ByteBuffer getBuffer() {
- + // Create non-direct buffers. On Pixel 4, creating direct buffer costs around 0.1 ms,
- + // which can be 5x ~ 10x longer compared to non-direct buffer backed by arrays (around
- + // 0.01ms), so generally we don't create direct buffer for every invocation.
- + ByteBuffer byteBuffer =
- + ByteBuffer.allocate(DataType.FLOAT32.byteSize() * buffer.length);
- + byteBuffer.order(ByteOrder.nativeOrder());
- + FloatBuffer result = byteBuffer.asFloatBuffer();
- + result.put(buffer, nextIndex, buffer.length - nextIndex);
- + result.put(buffer, 0, nextIndex);
- + byteBuffer.rewind();
- + return byteBuffer;
- + }
- +
- + public int getCapacity() {
- + return buffer.length;
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java
- index 776391b526b47..6090f85d99083 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java
- @@ -17,6 +17,10 @@ package org.tensorflow.lite.support.common;
-
- import android.content.Context;
- import android.content.res.AssetFileDescriptor;
- +
- +import org.checkerframework.checker.nullness.qual.NonNull;
- +import org.tensorflow.lite.support.common.internal.SupportPreconditions;
- +
- import java.io.BufferedReader;
- import java.io.FileInputStream;
- import java.io.IOException;
- @@ -28,160 +32,159 @@ import java.nio.channels.FileChannel;
- import java.nio.charset.Charset;
- import java.util.ArrayList;
- import java.util.List;
- -import org.checkerframework.checker.nullness.qual.NonNull;
- -import org.tensorflow.lite.support.common.internal.SupportPreconditions;
-
- /** File I/O utilities. */
- public class FileUtil {
- - private FileUtil() {}
- -
- - /**
- - * Loads labels from the label file into a list of strings.
- - *
- - * <p>A legal label file is the plain text file whose contents are split into lines, and each line
- - * is an individual value. The file should be in assets of the context.
- - *
- - * @param context The context holds assets.
- - * @param filePath The path of the label file, relative with assets directory.
- - * @return a list of labels.
- - * @throws IOException if error occurs to open or read the file.
- - */
- - @NonNull
- - public static List<String> loadLabels(@NonNull Context context, @NonNull String filePath)
- - throws IOException {
- - return loadLabels(context, filePath, Charset.defaultCharset());
- - }
- -
- - /**
- - * Loads labels from the label file into a list of strings.
- - *
- - * <p>A legal label file is the plain text file whose contents are split into lines, and each line
- - * is an individual value. The empty lines will be ignored. The file should be in assets of the
- - * context.
- - *
- - * @param context The context holds assets.
- - * @param filePath The path of the label file, relative with assets directory.
- - * @param cs {@code Charset} to use when decoding content of label file.
- - * @return a list of labels.
- - * @throws IOException if error occurs to open or read the file.
- - */
- - @NonNull
- - public static List<String> loadLabels(
- - @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
- - SupportPreconditions.checkNotNull(context, "Context cannot be null.");
- - SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
- - try (InputStream inputStream = context.getAssets().open(filePath)) {
- - return loadLabels(inputStream, cs);
- + private FileUtil() {}
- +
- + /**
- + * Loads labels from the label file into a list of strings.
- + *
- + * <p>A legal label file is the plain text file whose contents are split into lines, and each
- + * line is an individual value. The file should be in assets of the context.
- + *
- + * @param context The context holds assets.
- + * @param filePath The path of the label file, relative with assets directory.
- + * @return a list of labels.
- + * @throws IOException if error occurs to open or read the file.
- + */
- + @NonNull
- + public static List<String> loadLabels(@NonNull Context context, @NonNull String filePath)
- + throws IOException {
- + return loadLabels(context, filePath, Charset.defaultCharset());
- }
- - }
- -
- - /**
- - * Loads labels from an input stream of an opened label file. See details for label files in
- - * {@link FileUtil#loadLabels(Context, String)}.
- - *
- - * @param inputStream the input stream of an opened label file.
- - * @return a list of labels.
- - * @throws IOException if error occurs to open or read the file.
- - */
- - @NonNull
- - public static List<String> loadLabels(@NonNull InputStream inputStream) throws IOException {
- - return loadLabels(inputStream, Charset.defaultCharset());
- - }
- -
- - /**
- - * Loads labels from an input stream of an opened label file. See details for label files in
- - * {@link FileUtil#loadLabels(Context, String)}.
- - *
- - * @param inputStream the input stream of an opened label file.
- - * @param cs {@code Charset} to use when decoding content of label file.
- - * @return a list of labels.
- - * @throws IOException if error occurs to open or read the file.
- - */
- - @NonNull
- - public static List<String> loadLabels(@NonNull InputStream inputStream, Charset cs)
- - throws IOException {
- - List<String> labels = new ArrayList<>();
- - try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, cs))) {
- - String line;
- - while ((line = reader.readLine()) != null) {
- - if (line.trim().length() > 0) {
- - labels.add(line);
- +
- + /**
- + * Loads labels from the label file into a list of strings.
- + *
- + * <p>A legal label file is the plain text file whose contents are split into lines, and each
- + * line is an individual value. The empty lines will be ignored. The file should be in assets of
- + * the context.
- + *
- + * @param context The context holds assets.
- + * @param filePath The path of the label file, relative with assets directory.
- + * @param cs {@code Charset} to use when decoding content of label file.
- + * @return a list of labels.
- + * @throws IOException if error occurs to open or read the file.
- + */
- + @NonNull
- + public static List<String> loadLabels(
- + @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
- + SupportPreconditions.checkNotNull(context, "Context cannot be null.");
- + SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
- + try (InputStream inputStream = context.getAssets().open(filePath)) {
- + return loadLabels(inputStream, cs);
- }
- - }
- - return labels;
- }
- - }
- -
- - /**
- - * Loads a vocabulary file (a single-column text file) into a list of strings.
- - *
- - * <p>A vocabulary file is a single-column plain text file whose contents are split into lines,
- - * and each line is an individual value. The file should be in assets of the context.
- - *
- - * @param context The context holds assets.
- - * @param filePath The path of the vocabulary file, relative with assets directory.
- - * @return a list of vocabulary words.
- - * @throws IOException if error occurs to open or read the file.
- - */
- - @NonNull
- - public static List<String> loadSingleColumnTextFile(
- - @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
- - return loadLabels(context, filePath, cs);
- - }
- -
- - /**
- - * Loads vocabulary from an input stream of an opened vocabulary file (which is a single-column
- - * text file).
- - *
- - * <p>A vocabulary file is a single-column plain text file whose contents are split into lines,
- - * and each line is an individual value. The file should be in assets of the context.
- - *
- - * @param inputStream the input stream of an opened vocabulary file.
- - * @return a list of vocabulary words.
- - * @throws IOException if error occurs to open or read the file.
- - */
- - @NonNull
- - public static List<String> loadSingleColumnTextFile(@NonNull InputStream inputStream, Charset cs)
- - throws IOException {
- - return loadLabels(inputStream, cs);
- - }
- -
- - /**
- - * Loads a file from the asset folder through memory mapping.
- - *
- - * @param context Application context to access assets.
- - * @param filePath Asset path of the file.
- - * @return the loaded memory mapped file.
- - * @throws IOException if an I/O error occurs when loading the tflite model.
- - */
- - @NonNull
- - public static MappedByteBuffer loadMappedFile(@NonNull Context context, @NonNull String filePath)
- - throws IOException {
- - SupportPreconditions.checkNotNull(context, "Context should not be null.");
- - SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
- - try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
- - FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
- - FileChannel fileChannel = inputStream.getChannel();
- - long startOffset = fileDescriptor.getStartOffset();
- - long declaredLength = fileDescriptor.getDeclaredLength();
- - return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
- +
- + /**
- + * Loads labels from an input stream of an opened label file. See details for label files in
- + * {@link FileUtil#loadLabels(Context, String)}.
- + *
- + * @param inputStream the input stream of an opened label file.
- + * @return a list of labels.
- + * @throws IOException if error occurs to open or read the file.
- + */
- + @NonNull
- + public static List<String> loadLabels(@NonNull InputStream inputStream) throws IOException {
- + return loadLabels(inputStream, Charset.defaultCharset());
- + }
- +
- + /**
- + * Loads labels from an input stream of an opened label file. See details for label files in
- + * {@link FileUtil#loadLabels(Context, String)}.
- + *
- + * @param inputStream the input stream of an opened label file.
- + * @param cs {@code Charset} to use when decoding content of label file.
- + * @return a list of labels.
- + * @throws IOException if error occurs to open or read the file.
- + */
- + @NonNull
- + public static List<String> loadLabels(@NonNull InputStream inputStream, Charset cs)
- + throws IOException {
- + List<String> labels = new ArrayList<>();
- + try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, cs))) {
- + String line;
- + while ((line = reader.readLine()) != null) {
- + if (line.trim().length() > 0) {
- + labels.add(line);
- + }
- + }
- + return labels;
- + }
- + }
- +
- + /**
- + * Loads a vocabulary file (a single-column text file) into a list of strings.
- + *
- + * <p>A vocabulary file is a single-column plain text file whose contents are split into lines,
- + * and each line is an individual value. The file should be in assets of the context.
- + *
- + * @param context The context holds assets.
- + * @param filePath The path of the vocabulary file, relative with assets directory.
- + * @return a list of vocabulary words.
- + * @throws IOException if error occurs to open or read the file.
- + */
- + @NonNull
- + public static List<String> loadSingleColumnTextFile(
- + @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
- + return loadLabels(context, filePath, cs);
- + }
- +
- + /**
- + * Loads vocabulary from an input stream of an opened vocabulary file (which is a single-column
- + * text file).
- + *
- + * <p>A vocabulary file is a single-column plain text file whose contents are split into lines,
- + * and each line is an individual value. The file should be in assets of the context.
- + *
- + * @param inputStream the input stream of an opened vocabulary file.
- + * @return a list of vocabulary words.
- + * @throws IOException if error occurs to open or read the file.
- + */
- + @NonNull
- + public static List<String> loadSingleColumnTextFile(
- + @NonNull InputStream inputStream, Charset cs) throws IOException {
- + return loadLabels(inputStream, cs);
- + }
- +
- + /**
- + * Loads a file from the asset folder through memory mapping.
- + *
- + * @param context Application context to access assets.
- + * @param filePath Asset path of the file.
- + * @return the loaded memory mapped file.
- + * @throws IOException if an I/O error occurs when loading the tflite model.
- + */
- + @NonNull
- + public static MappedByteBuffer loadMappedFile(
- + @NonNull Context context, @NonNull String filePath) throws IOException {
- + SupportPreconditions.checkNotNull(context, "Context should not be null.");
- + SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
- + try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
- + FileInputStream inputStream =
- + new FileInputStream(fileDescriptor.getFileDescriptor())) {
- + FileChannel fileChannel = inputStream.getChannel();
- + long startOffset = fileDescriptor.getStartOffset();
- + long declaredLength = fileDescriptor.getDeclaredLength();
- + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
- + }
- + }
- +
- + /**
- + * Loads a binary file from the asset folder.
- + *
- + * @param context Application context to access assets.
- + * @param filePath Asset path of the file.
- + * @return the byte array for the binary file.
- + * @throws IOException if an I/O error occurs when loading file.
- + */
- + @NonNull
- + public static byte[] loadByteFromFile(@NonNull Context context, @NonNull String filePath)
- + throws IOException {
- + ByteBuffer buffer = loadMappedFile(context, filePath);
- + byte[] byteArray = new byte[buffer.remaining()];
- + buffer.get(byteArray);
- + return byteArray;
- }
- - }
- -
- - /**
- - * Loads a binary file from the asset folder.
- - *
- - * @param context Application context to access assets.
- - * @param filePath Asset path of the file.
- - * @return the byte array for the binary file.
- - * @throws IOException if an I/O error occurs when loading file.
- - */
- - @NonNull
- - public static byte[] loadByteFromFile(@NonNull Context context, @NonNull String filePath)
- - throws IOException {
- - ByteBuffer buffer = loadMappedFile(context, filePath);
- - byte[] byteArray = new byte[buffer.remaining()];
- - buffer.get(byteArray);
- - return byteArray;
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java
- index 38dfe8818cbbc..45dfc4d9d868b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java
- @@ -20,12 +20,11 @@ package org.tensorflow.lite.support.common;
- * @param <T> The class which Operator handles.
- */
- public interface Operator<T> {
- -
- - /**
- - * Applies an operation on a T object, returning a T object.
- - *
- - * <p>Note: The returned object could probably be the same one with given input, and given input
- - * could probably be changed.
- - */
- - T apply(T x);
- + /**
- + * Applies an operation on a T object, returning a T object.
- + *
- + * <p>Note: The returned object could probably be the same one with given input, and given input
- + * could probably be changed.
- + */
- + T apply(T x);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java
- index 9d0024b2f5887..a94adb89b8666 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java
- @@ -17,5 +17,5 @@ package org.tensorflow.lite.support.common;
-
- /** Processes T object with prepared {@code Operator<T>}. */
- public interface Processor<T> {
- - T process(T input);
- + T process(T input);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java
- index af688c863c254..aa900b7c93d87 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java
- @@ -15,13 +15,14 @@ limitations under the License.
-
- package org.tensorflow.lite.support.common;
-
- +import org.checkerframework.checker.nullness.qual.NonNull;
- +import org.tensorflow.lite.support.common.internal.SupportPreconditions;
- +
- import java.util.ArrayList;
- import java.util.Collections;
- import java.util.HashMap;
- import java.util.List;
- import java.util.Map;
- -import org.checkerframework.checker.nullness.qual.NonNull;
- -import org.tensorflow.lite.support.common.internal.SupportPreconditions;
-
- /**
- * A processor base class that chains a serial of {@code Operator<T>} and executes them.
- @@ -32,52 +33,50 @@ import org.tensorflow.lite.support.common.internal.SupportPreconditions;
- * @param <T> The type that the Operator is handling.
- */
- public class SequentialProcessor<T> implements Processor<T> {
- + /** List of operators added to this {@link SequentialProcessor}. */
- + protected final List<Operator<T>> operatorList;
- + /**
- + * The {@link Map} between the operator name and the corresponding op indexes in {@code
- + * operatorList}. An operator may be added multiple times into this {@link SequentialProcessor}.
- + */
- + protected final Map<String, List<Integer>> operatorIndex;
-
- - /** List of operators added to this {@link SequentialProcessor}. */
- - protected final List<Operator<T>> operatorList;
- - /**
- - * The {@link Map} between the operator name and the corresponding op indexes in {@code
- - * operatorList}. An operator may be added multiple times into this {@link SequentialProcessor}.
- - */
- - protected final Map<String, List<Integer>> operatorIndex;
- -
- - protected SequentialProcessor(Builder<T> builder) {
- - operatorList = builder.operatorList;
- - operatorIndex = Collections.unmodifiableMap(builder.operatorIndex);
- - }
- + protected SequentialProcessor(Builder<T> builder) {
- + operatorList = builder.operatorList;
- + operatorIndex = Collections.unmodifiableMap(builder.operatorIndex);
- + }
-
- - @Override
- - public T process(T x) {
- - for (Operator<T> op : operatorList) {
- - x = op.apply(x);
- + @Override
- + public T process(T x) {
- + for (Operator<T> op : operatorList) {
- + x = op.apply(x);
- + }
- + return x;
- }
- - return x;
- - }
-
- - /** The inner builder class to build a Sequential Processor. */
- - protected static class Builder<T> {
- + /** The inner builder class to build a Sequential Processor. */
- + protected static class Builder<T> {
- + private final List<Operator<T>> operatorList;
- + private final Map<String, List<Integer>> operatorIndex;
-
- - private final List<Operator<T>> operatorList;
- - private final Map<String, List<Integer>> operatorIndex;
- + protected Builder() {
- + operatorList = new ArrayList<>();
- + operatorIndex = new HashMap<>();
- + }
-
- - protected Builder() {
- - operatorList = new ArrayList<>();
- - operatorIndex = new HashMap<>();
- - }
- -
- - public Builder<T> add(@NonNull Operator<T> op) {
- - SupportPreconditions.checkNotNull(op, "Adding null Op is illegal.");
- - operatorList.add(op);
- - String operatorName = op.getClass().getName();
- - if (!operatorIndex.containsKey(operatorName)) {
- - operatorIndex.put(operatorName, new ArrayList<Integer>());
- - }
- - operatorIndex.get(operatorName).add(operatorList.size() - 1);
- - return this;
- - }
- + public Builder<T> add(@NonNull Operator<T> op) {
- + SupportPreconditions.checkNotNull(op, "Adding null Op is illegal.");
- + operatorList.add(op);
- + String operatorName = op.getClass().getName();
- + if (!operatorIndex.containsKey(operatorName)) {
- + operatorIndex.put(operatorName, new ArrayList<Integer>());
- + }
- + operatorIndex.get(operatorName).add(operatorList.size() - 1);
- + return this;
- + }
-
- - public SequentialProcessor<T> build() {
- - return new SequentialProcessor<T>(this);
- + public SequentialProcessor<T> build() {
- + return new SequentialProcessor<T>(this);
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java
- index d1b7021df257c..692c2d479dcce 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java
- @@ -21,7 +21,7 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- * Applies some operation on TensorBuffers.
- */
- public interface TensorOperator extends Operator<TensorBuffer> {
- - /** @see Operator#apply(Object) . */
- - @Override
- - TensorBuffer apply(TensorBuffer input);
- + /** @see Operator#apply(Object) . */
- + @Override
- + TensorBuffer apply(TensorBuffer input);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java
- index 8096d0c764bab..faad66edeb04e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java
- @@ -32,37 +32,36 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- * @see TensorProcessor#process to apply the processor on a {@code TensorBuffer}.
- */
- public class TensorProcessor extends SequentialProcessor<TensorBuffer> {
- - private TensorProcessor(Builder builder) {
- - super(builder);
- - }
- -
- - /** The Builder to create an {@link TensorProcessor}, which could be executed later. */
- - public static class Builder extends SequentialProcessor.Builder<TensorBuffer> {
- -
- - /**
- - * Creates a Builder to build {@link TensorProcessor}.
- - *
- - * @see #add(TensorOperator) to add an Op.
- - * @see #build() to complete the building process and get a built Processor.
- - */
- - public Builder() {
- - super();
- + private TensorProcessor(Builder builder) {
- + super(builder);
- }
-
- - /**
- - * Adds an {@link TensorOperator} into the Operator chain.
- - *
- - * @param op the Operator instance to be executed then.
- - */
- - public TensorProcessor.Builder add(TensorOperator op) {
- - super.add(op);
- - return this;
- - }
- + /** The Builder to create an {@link TensorProcessor}, which could be executed later. */
- + public static class Builder extends SequentialProcessor.Builder<TensorBuffer> {
- + /**
- + * Creates a Builder to build {@link TensorProcessor}.
- + *
- + * @see #add(TensorOperator) to add an Op.
- + * @see #build() to complete the building process and get a built Processor.
- + */
- + public Builder() {
- + super();
- + }
- +
- + /**
- + * Adds an {@link TensorOperator} into the Operator chain.
- + *
- + * @param op the Operator instance to be executed then.
- + */
- + public TensorProcessor.Builder add(TensorOperator op) {
- + super.add(op);
- + return this;
- + }
-
- - /** Completes the building process and gets the {@link TensorProcessor} instance. */
- - @Override
- - public TensorProcessor build() {
- - return new TensorProcessor(this);
- + /** Completes the building process and gets the {@link TensorProcessor} instance. */
- + @Override
- + public TensorProcessor build() {
- + return new TensorProcessor(this);
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java
- index e3e962a5f8252..29faa545b71f2 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java
- @@ -19,164 +19,168 @@ import org.checkerframework.checker.nullness.qual.Nullable;
-
- /** Static error checking util methods. */
- public final class SupportPreconditions {
- - /**
- - * Ensures that an object reference passed as a parameter to the calling method is not null.
- - *
- - * @param reference an object reference
- - * @return the non-null reference that was validated
- - * @throws NullPointerException if {@code reference} is null
- - */
- - public static <T extends Object> T checkNotNull(T reference) {
- - if (reference == null) {
- - throw new NullPointerException("The object reference is null.");
- + /**
- + * Ensures that an object reference passed as a parameter to the calling method is not null.
- + *
- + * @param reference an object reference
- + * @return the non-null reference that was validated
- + * @throws NullPointerException if {@code reference} is null
- + */
- + public static <T extends Object> T checkNotNull(T reference) {
- + if (reference == null) {
- + throw new NullPointerException("The object reference is null.");
- + }
- + return reference;
- }
- - return reference;
- - }
- -
- - /**
- - * Ensures that an object reference passed as a parameter to the calling method is not null.
- - *
- - * @param reference an object reference
- - * @param errorMessage the exception message to use if the check fails; will be converted to a
- - * string using {@link String#valueOf(Object)}
- - * @return the non-null reference that was validated
- - * @throws NullPointerException if {@code reference} is null
- - */
- - public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
- - if (reference == null) {
- - throw new NullPointerException(String.valueOf(errorMessage));
- +
- + /**
- + * Ensures that an object reference passed as a parameter to the calling method is not null.
- + *
- + * @param reference an object reference
- + * @param errorMessage the exception message to use if the check fails; will be converted to a
- + * string using {@link String#valueOf(Object)}
- + * @return the non-null reference that was validated
- + * @throws NullPointerException if {@code reference} is null
- + */
- + public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
- + if (reference == null) {
- + throw new NullPointerException(String.valueOf(errorMessage));
- + }
- + return reference;
- + }
- +
- + /**
- + * Ensures that the given String is not empty and not null.
- + *
- + * @param string the String to test
- + * @return the non-null non-empty String that was validated
- + * @throws IllegalArgumentException if {@code string} is null or empty
- + */
- + public static String checkNotEmpty(String string) {
- + if (string == null || string.length() == 0) {
- + throw new IllegalArgumentException("Given String is empty or null.");
- + }
- + return string;
- }
- - return reference;
- - }
- -
- - /**
- - * Ensures that the given String is not empty and not null.
- - *
- - * @param string the String to test
- - * @return the non-null non-empty String that was validated
- - * @throws IllegalArgumentException if {@code string} is null or empty
- - */
- - public static String checkNotEmpty(String string) {
- - if (string == null || string.length() == 0) {
- - throw new IllegalArgumentException("Given String is empty or null.");
- +
- + /**
- + * Ensures that the given String is not empty and not null.
- + *
- + * @param string the String to test
- + * @param errorMessage the exception message to use if the check fails; will be converted to a
- + * string using {@link String#valueOf(Object)}
- + * @return the non-null non-empty String that was validated
- + * @throws IllegalArgumentException if {@code string} is null or empty
- + */
- + public static String checkNotEmpty(String string, Object errorMessage) {
- + if (string == null || string.length() == 0) {
- + throw new IllegalArgumentException(String.valueOf(errorMessage));
- + }
- + return string;
- }
- - return string;
- - }
- -
- - /**
- - * Ensures that the given String is not empty and not null.
- - *
- - * @param string the String to test
- - * @param errorMessage the exception message to use if the check fails; will be converted to a
- - * string using {@link String#valueOf(Object)}
- - * @return the non-null non-empty String that was validated
- - * @throws IllegalArgumentException if {@code string} is null or empty
- - */
- - public static String checkNotEmpty(String string, Object errorMessage) {
- - if (string == null || string.length() == 0) {
- - throw new IllegalArgumentException(String.valueOf(errorMessage));
- +
- + /**
- + * Ensures the truth of an expression involving one or more parameters to the calling method.
- + *
- + * @param expression a boolean expression.
- + * @throws IllegalArgumentException if {@code expression} is false.
- + */
- + public static void checkArgument(boolean expression) {
- + if (!expression) {
- + throw new IllegalArgumentException();
- + }
- }
- - return string;
- - }
- -
- - /**
- - * Ensures the truth of an expression involving one or more parameters to the calling method.
- - *
- - * @param expression a boolean expression.
- - * @throws IllegalArgumentException if {@code expression} is false.
- - */
- - public static void checkArgument(boolean expression) {
- - if (!expression) {
- - throw new IllegalArgumentException();
- +
- + /**
- + * Ensures the truth of an expression involving one or more parameters to the calling method.
- + *
- + * @param expression a boolean expression.
- + * @param errorMessage the exception message to use if the check fails; will be converted to a
- + * string using {@link String#valueOf(Object)}.
- + * @throws IllegalArgumentException if {@code expression} is false.
- + */
- + public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
- + if (!expression) {
- + throw new IllegalArgumentException(String.valueOf(errorMessage));
- + }
- }
- - }
- -
- - /**
- - * Ensures the truth of an expression involving one or more parameters to the calling method.
- - *
- - * @param expression a boolean expression.
- - * @param errorMessage the exception message to use if the check fails; will be converted to a
- - * string using {@link String#valueOf(Object)}.
- - * @throws IllegalArgumentException if {@code expression} is false.
- - */
- - public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
- - if (!expression) {
- - throw new IllegalArgumentException(String.valueOf(errorMessage));
- +
- + /**
- + * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of
- + * size
- + * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
- + *
- + * @param index a user-supplied index identifying an element of an array, list or string
- + * @param size the size of that array, list or string
- + * @return the value of {@code index}
- + * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code
- + * size}
- + * @throws IllegalArgumentException if {@code size} is negative
- + */
- + public static int checkElementIndex(int index, int size) {
- + return checkElementIndex(index, size, "index");
- }
- - }
- -
- - /**
- - * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
- - * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
- - *
- - * @param index a user-supplied index identifying an element of an array, list or string
- - * @param size the size of that array, list or string
- - * @return the value of {@code index}
- - * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
- - * @throws IllegalArgumentException if {@code size} is negative
- - */
- - public static int checkElementIndex(int index, int size) {
- - return checkElementIndex(index, size, "index");
- - }
- -
- - /**
- - * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
- - * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
- - *
- - * @param index a user-supplied index identifying an element of an array, list or string
- - * @param size the size of that array, list or string
- - * @param desc the text to use to describe this index in an error message
- - * @return the value of {@code index}
- - * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
- - * @throws IllegalArgumentException if {@code size} is negative
- - */
- - public static int checkElementIndex(int index, int size, @Nullable String desc) {
- - // Carefully optimized for execution by hotspot (explanatory comment above)
- - if (index < 0 || index >= size) {
- - throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
- +
- + /**
- + * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of
- + * size
- + * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
- + *
- + * @param index a user-supplied index identifying an element of an array, list or string
- + * @param size the size of that array, list or string
- + * @param desc the text to use to describe this index in an error message
- + * @return the value of {@code index}
- + * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code
- + * size}
- + * @throws IllegalArgumentException if {@code size} is negative
- + */
- + public static int checkElementIndex(int index, int size, @Nullable String desc) {
- + // Carefully optimized for execution by hotspot (explanatory comment above)
- + if (index < 0 || index >= size) {
- + throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
- + }
- + return index;
- }
- - return index;
- - }
- -
- - /**
- - * Ensures the truth of an expression involving the state of the calling instance, but not
- - * involving any parameters to the calling method.
- - *
- - * @param expression a boolean expression
- - * @throws IllegalStateException if {@code expression} is false
- - */
- - public static void checkState(boolean expression) {
- - if (!expression) {
- - throw new IllegalStateException();
- +
- + /**
- + * Ensures the truth of an expression involving the state of the calling instance, but not
- + * involving any parameters to the calling method.
- + *
- + * @param expression a boolean expression
- + * @throws IllegalStateException if {@code expression} is false
- + */
- + public static void checkState(boolean expression) {
- + if (!expression) {
- + throw new IllegalStateException();
- + }
- }
- - }
- -
- - /**
- - * Ensures the truth of an expression involving the state of the calling instance, but not
- - * involving any parameters to the calling method.
- - *
- - * @param expression a boolean expression
- - * @param errorMessage the exception message to use if the check fails; will be converted to a
- - * string using {@link String#valueOf(Object)}
- - * @throws IllegalStateException if {@code expression} is false
- - */
- - public static void checkState(boolean expression, @Nullable Object errorMessage) {
- - if (!expression) {
- - throw new IllegalStateException(String.valueOf(errorMessage));
- +
- + /**
- + * Ensures the truth of an expression involving the state of the calling instance, but not
- + * involving any parameters to the calling method.
- + *
- + * @param expression a boolean expression
- + * @param errorMessage the exception message to use if the check fails; will be converted to a
- + * string using {@link String#valueOf(Object)}
- + * @throws IllegalStateException if {@code expression} is false
- + */
- + public static void checkState(boolean expression, @Nullable Object errorMessage) {
- + if (!expression) {
- + throw new IllegalStateException(String.valueOf(errorMessage));
- + }
- }
- - }
- -
- - private static String badElementIndex(int index, int size, @Nullable String desc) {
- - if (index < 0) {
- - return String.format("%s (%s) must not be negative", desc, index);
- - } else if (size < 0) {
- - throw new IllegalArgumentException("negative size: " + size);
- - } else { // index >= size
- - return String.format("%s (%s) must be less than size (%s)", desc, index, size);
- +
- + private static String badElementIndex(int index, int size, @Nullable String desc) {
- + if (index < 0) {
- + return String.format("%s (%s) must not be negative", desc, index);
- + } else if (size < 0) {
- + throw new IllegalArgumentException("negative size: " + size);
- + } else { // index >= size
- + return String.format("%s (%s) must be less than size (%s)", desc, index, size);
- + }
- }
- - }
-
- - private SupportPreconditions() {
- - throw new AssertionError("SupportPreconditions is Uninstantiable.");
- - }
- + private SupportPreconditions() {
- + throw new AssertionError("SupportPreconditions is Uninstantiable.");
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java
- index 742a1ef90994c..a14cd1f1e503d 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java
- @@ -22,34 +22,33 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- /** Casts a {@link TensorBuffer} to a specified data type. */
- public class CastOp implements TensorOperator {
- + private final DataType destinationType;
- +
- + /**
- + * Constructs a CastOp.
- + *
- + * <p>Note: For only converting type for a certain {@link TensorBuffer} on-the-fly rather than
- + * in a processor, please directly use {@link TensorBuffer#createFrom(TensorBuffer, DataType)}.
- + *
- + * <p>When this Op is executed, if the original {@link TensorBuffer} is already in {@code
- + * destinationType}, the original buffer will be directly returned.
- + *
- + * @param destinationType The type of the casted {@link TensorBuffer}.
- + * @throws IllegalArgumentException if {@code destinationType} is neither {@link DataType#UINT8}
- + * nor {@link DataType#FLOAT32}.
- + */
- + public CastOp(DataType destinationType) {
- + SupportPreconditions.checkArgument(
- + destinationType == DataType.UINT8 || destinationType == DataType.FLOAT32,
- + "Destination type " + destinationType + " is not supported.");
- + this.destinationType = destinationType;
- + }
-
- - private final DataType destinationType;
- -
- - /**
- - * Constructs a CastOp.
- - *
- - * <p>Note: For only converting type for a certain {@link TensorBuffer} on-the-fly rather than in
- - * a processor, please directly use {@link TensorBuffer#createFrom(TensorBuffer, DataType)}.
- - *
- - * <p>When this Op is executed, if the original {@link TensorBuffer} is already in {@code
- - * destinationType}, the original buffer will be directly returned.
- - *
- - * @param destinationType The type of the casted {@link TensorBuffer}.
- - * @throws IllegalArgumentException if {@code destinationType} is neither {@link DataType#UINT8}
- - * nor {@link DataType#FLOAT32}.
- - */
- - public CastOp(DataType destinationType) {
- - SupportPreconditions.checkArgument(
- - destinationType == DataType.UINT8 || destinationType == DataType.FLOAT32,
- - "Destination type " + destinationType + " is not supported.");
- - this.destinationType = destinationType;
- - }
- -
- - @Override
- - public TensorBuffer apply(TensorBuffer input) {
- - if (input.getDataType() == destinationType) {
- - return input;
- + @Override
- + public TensorBuffer apply(TensorBuffer input) {
- + if (input.getDataType() == destinationType) {
- + return input;
- + }
- + return TensorBuffer.createFrom(input, destinationType);
- }
- - return TensorBuffer.createFrom(input, destinationType);
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java
- index 1881747870be3..8b6d183189b7f 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java
- @@ -32,9 +32,8 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- * as 0.
- */
- public class DequantizeOp extends NormalizeOp implements TensorOperator {
- -
- - public DequantizeOp(float zeroPoint, float scale) {
- - // Quantization: f = (q - z) * s
- - super(zeroPoint, 1 / scale);
- - }
- + public DequantizeOp(float zeroPoint, float scale) {
- + // Quantization: f = (q - z) * s
- + super(zeroPoint, 1 / scale);
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java
- index cff4d0b55d60a..912df13b59cec 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java
- @@ -26,135 +26,134 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat;
- * Normalizes a {@link TensorBuffer} with given mean and stddev: output = (input - mean) / stddev.
- */
- public class NormalizeOp implements TensorOperator {
- + // mean.length should always be equal to stddev.length and always >= 1.
- + private final float[] mean;
- + private final float[] stddev;
- + private final int numChannels;
- + private final boolean isIdentityOp;
-
- - // mean.length should always be equal to stddev.length and always >= 1.
- - private final float[] mean;
- - private final float[] stddev;
- - private final int numChannels;
- - private final boolean isIdentityOp;
- + /**
- + * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which
- + * satisfies:
- + *
- + * <pre>
- + * output = (input - mean) / stddev
- + * </pre>
- + *
- + * <p>In the following two cases, reset {@code mean} to 0 and {@code stddev} to 1 to bypass the
- + * normalization. <br>
- + * 1. Both {@code mean} and {code stddev} are 0. <br>
- + * 2. {@code mean} is 0 and {stddev} is Infinity.
- + *
- + * <p>Note: If {@code mean} is set to 0 and {@code stddev} is set to 1, no computation will
- + * happen, and original input will be directly returned in execution.
- + *
- + * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
- + * present, except when the input is a {@link DataType#UINT8} tensor, {@code mean} is set to 0
- + * and
- + * {@code stddev} is set to 1, so that the original {@link DataType#UINT8} tensor is returned.
- + *
- + * @param mean the mean value to be subtracted first.
- + * @param stddev the standard deviation value to divide then.
- + * @throws IllegalArgumentException if {@code stddev} is zero.
- + */
- + public NormalizeOp(float mean, float stddev) {
- + // Make exceptions to the cases that
- + // 1. Both mean and stddev are 0.0f. This may happen when reading the normalization
- + // parameters from a tensor which does not have the values populated in the metadata. The
- + // same situation may also happen to the quantization parameters.
- + // 2. mean is 0.0f and stddev is Infinity. This may happen when reading the quantization
- + // parameters from a tensor which does not have the values populated in the metadata, and
- + // then passing the parameters into the DequantizeOp. Bypass both of the two cases, by
- + // reseting stddev to 1.0f.
- + if (mean == 0.0f && (stddev == 0.0f || Float.isInfinite(stddev))) {
- + stddev = 1.0f;
- + }
-
- - /**
- - * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which
- - * satisfies:
- - *
- - * <pre>
- - * output = (input - mean) / stddev
- - * </pre>
- - *
- - * <p>In the following two cases, reset {@code mean} to 0 and {@code stddev} to 1 to bypass the
- - * normalization. <br>
- - * 1. Both {@code mean} and {code stddev} are 0. <br>
- - * 2. {@code mean} is 0 and {stddev} is Infinity.
- - *
- - * <p>Note: If {@code mean} is set to 0 and {@code stddev} is set to 1, no computation will
- - * happen, and original input will be directly returned in execution.
- - *
- - * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
- - * present, except when the input is a {@link DataType#UINT8} tensor, {@code mean} is set to 0 and
- - * {@code stddev} is set to 1, so that the original {@link DataType#UINT8} tensor is returned.
- - *
- - * @param mean the mean value to be subtracted first.
- - * @param stddev the standard deviation value to divide then.
- - * @throws IllegalArgumentException if {@code stddev} is zero.
- - */
- - public NormalizeOp(float mean, float stddev) {
- - // Make exceptions to the cases that
- - // 1. Both mean and stddev are 0.0f. This may happen when reading the normalization parameters
- - // from a tensor which does not have the values populated in the metadata. The same situation
- - // may also happen to the quantization parameters.
- - // 2. mean is 0.0f and stddev is Infinity. This may happen when reading the quantization
- - // parameters from a tensor which does not have the values populated in the metadata, and then
- - // passing the parameters into the DequantizeOp.
- - // Bypass both of the two cases, by reseting stddev to 1.0f.
- - if (mean == 0.0f && (stddev == 0.0f || Float.isInfinite(stddev))) {
- - stddev = 1.0f;
- - }
- + SupportPreconditions.checkArgument(stddev != 0.0f, "Stddev cannot be zero.");
- + boolean meansIsZeroAndDevsIs1 = false;
- + if (mean == 0.0f && stddev == 1.0f) {
- + meansIsZeroAndDevsIs1 = true;
- + }
-
- - SupportPreconditions.checkArgument(stddev != 0.0f, "Stddev cannot be zero.");
- - boolean meansIsZeroAndDevsIs1 = false;
- - if (mean == 0.0f && stddev == 1.0f) {
- - meansIsZeroAndDevsIs1 = true;
- + this.isIdentityOp = meansIsZeroAndDevsIs1;
- + this.mean = new float[] {mean};
- + this.stddev = new float[] {stddev};
- + this.numChannels = 1;
- }
-
- - this.isIdentityOp = meansIsZeroAndDevsIs1;
- - this.mean = new float[] {mean};
- - this.stddev = new float[] {stddev};
- - this.numChannels = 1;
- - }
- -
- - /**
- - * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which
- - * satisfies:
- - *
- - * <pre>
- - * // Pseudo code. [...][i] means a certain element whose channel id is i.
- - * output[...][i] = (input[...][i] - mean[i]) / stddev[i]
- - * </pre>
- - *
- - * <p>Note: If all values in {@code mean} are set to 0 and all {@code stddev} are set to 1, no
- - * computation will happen, and original input will be directly returned in execution.
- - *
- - * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
- - * present, except that the input is a {@link DataType#UINT8} tensor, all {@code mean} are set to
- - * 0 and all {@code stddev} are set to 1.
- - *
- - * @param mean the mean values to be subtracted first for each channel.
- - * @param stddev the standard deviation values to divide then for each channel.
- - * @throws IllegalArgumentException if any {@code stddev} is zero, or {@code mean} has different
- - * number of elements with {@code stddev}, or any of them is empty.
- - */
- - public NormalizeOp(@NonNull float[] mean, @NonNull float[] stddev) {
- - SupportPreconditions.checkNotNull(mean, "Mean cannot be null");
- - SupportPreconditions.checkNotNull(stddev, "Stddev cannot be null");
- - SupportPreconditions.checkArgument(
- - mean.length == stddev.length,
- - "Per channel normalization requires same number of means and stddevs");
- - SupportPreconditions.checkArgument(mean.length > 0, "Means and stddevs are empty.");
- - this.mean = mean.clone();
- - this.stddev = stddev.clone();
- - boolean allMeansAreZeroAndAllDevsAre1 = true;
- - this.numChannels = mean.length;
- - for (int i = 0; i < numChannels; i++) {
- - SupportPreconditions.checkArgument(this.stddev[i] != 0, "Stddev cannot be zero.");
- - if (this.stddev[i] != 1 || this.mean[i] != 0) {
- - allMeansAreZeroAndAllDevsAre1 = false;
- - }
- + /**
- + * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which
- + * satisfies:
- + *
- + * <pre>
- + * // Pseudo code. [...][i] means a certain element whose channel id is i.
- + * output[...][i] = (input[...][i] - mean[i]) / stddev[i]
- + * </pre>
- + *
- + * <p>Note: If all values in {@code mean} are set to 0 and all {@code stddev} are set to 1, no
- + * computation will happen, and original input will be directly returned in execution.
- + *
- + * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
- + * present, except that the input is a {@link DataType#UINT8} tensor, all {@code mean} are set
- + * to 0 and all {@code stddev} are set to 1.
- + *
- + * @param mean the mean values to be subtracted first for each channel.
- + * @param stddev the standard deviation values to divide then for each channel.
- + * @throws IllegalArgumentException if any {@code stddev} is zero, or {@code mean} has different
- + * number of elements with {@code stddev}, or any of them is empty.
- + */
- + public NormalizeOp(@NonNull float[] mean, @NonNull float[] stddev) {
- + SupportPreconditions.checkNotNull(mean, "Mean cannot be null");
- + SupportPreconditions.checkNotNull(stddev, "Stddev cannot be null");
- + SupportPreconditions.checkArgument(mean.length == stddev.length,
- + "Per channel normalization requires same number of means and stddevs");
- + SupportPreconditions.checkArgument(mean.length > 0, "Means and stddevs are empty.");
- + this.mean = mean.clone();
- + this.stddev = stddev.clone();
- + boolean allMeansAreZeroAndAllDevsAre1 = true;
- + this.numChannels = mean.length;
- + for (int i = 0; i < numChannels; i++) {
- + SupportPreconditions.checkArgument(this.stddev[i] != 0, "Stddev cannot be zero.");
- + if (this.stddev[i] != 1 || this.mean[i] != 0) {
- + allMeansAreZeroAndAllDevsAre1 = false;
- + }
- + }
- + this.isIdentityOp = allMeansAreZeroAndAllDevsAre1;
- }
- - this.isIdentityOp = allMeansAreZeroAndAllDevsAre1;
- - }
-
- - /**
- - * Applies the defined normalization on given tensor and returns the result.
- - *
- - * <p>Note: {@code input} is possibly the same instance with the output.
- - *
- - * @param input input tensor. It may be the same instance with the output.
- - * @return output tensor.
- - */
- - @Override
- - @NonNull
- - public TensorBuffer apply(@NonNull TensorBuffer input) {
- - if (isIdentityOp) {
- - return input;
- - }
- - int[] shape = input.getShape();
- - SupportPreconditions.checkArgument(
- - numChannels == 1 || (shape.length != 0 && shape[shape.length - 1] == numChannels),
- - "Number of means (stddevs) is not same with number of channels (size of last axis).");
- - // TODO(136750944): Eliminate the array copy here.
- - float[] values = input.getFloatArray();
- - int j = 0;
- - for (int i = 0; i < values.length; i++) {
- - values[i] = (values[i] - mean[j]) / stddev[j];
- - j = (j + 1) % numChannels;
- - }
- - TensorBuffer output;
- - if (input.isDynamic()) {
- - output = TensorBufferFloat.createDynamic(DataType.FLOAT32);
- - } else {
- - output = TensorBufferFloat.createFixedSize(shape, DataType.FLOAT32);
- + /**
- + * Applies the defined normalization on given tensor and returns the result.
- + *
- + * <p>Note: {@code input} is possibly the same instance with the output.
- + *
- + * @param input input tensor. It may be the same instance with the output.
- + * @return output tensor.
- + */
- + @Override
- + @NonNull
- + public TensorBuffer apply(@NonNull TensorBuffer input) {
- + if (isIdentityOp) {
- + return input;
- + }
- + int[] shape = input.getShape();
- + SupportPreconditions.checkArgument(
- + numChannels == 1 || (shape.length != 0 && shape[shape.length - 1] == numChannels),
- + "Number of means (stddevs) is not same with number of channels (size of last axis).");
- + // TODO(136750944): Eliminate the array copy here.
- + float[] values = input.getFloatArray();
- + int j = 0;
- + for (int i = 0; i < values.length; i++) {
- + values[i] = (values[i] - mean[j]) / stddev[j];
- + j = (j + 1) % numChannels;
- + }
- + TensorBuffer output;
- + if (input.isDynamic()) {
- + output = TensorBufferFloat.createDynamic(DataType.FLOAT32);
- + } else {
- + output = TensorBufferFloat.createFixedSize(shape, DataType.FLOAT32);
- + }
- + output.loadArray(values, shape);
- + return output;
- }
- - output.loadArray(values, shape);
- - return output;
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java
- index 8b3e82aee13ef..84cb856fd4ed9 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java
- @@ -33,9 +33,8 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- * as 0.
- */
- public class QuantizeOp extends NormalizeOp implements TensorOperator {
- -
- - public QuantizeOp(float zeroPoint, float scale) {
- - // Quantization: f = (q - z) * s, i.e. q = f / s + z = (f - (-z * s)) / s
- - super(-zeroPoint * scale, scale);
- - }
- + public QuantizeOp(float zeroPoint, float scale) {
- + // Quantization: f = (q - z) * s, i.e. q = f / s + z = (f - (-z * s)) / s
- + super(-zeroPoint * scale, scale);
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java
- index 9bee78d139efa..f9b6a1f874bff 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java
- @@ -21,67 +21,67 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
- import android.graphics.Bitmap;
- import android.graphics.Bitmap.Config;
- import android.media.Image;
- +
- import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- /** Holds a {@link Bitmap} and converts it to other image formats as needed. */
- final class BitmapContainer implements ImageContainer {
- -
- - private final Bitmap bitmap;
- -
- - /**
- - * Creates a {@link BitmapContainer} object with ARGB_8888 {@link Bitmap}.
- - *
- - * @throws IllegalArgumentException if the bitmap configuration is not ARGB_8888
- - */
- - static BitmapContainer create(Bitmap bitmap) {
- - return new BitmapContainer(bitmap);
- - }
- -
- - private BitmapContainer(Bitmap bitmap) {
- - checkNotNull(bitmap, "Cannot load null bitmap.");
- - checkArgument(
- - bitmap.getConfig().equals(Config.ARGB_8888), "Only supports loading ARGB_8888 bitmaps.");
- - this.bitmap = bitmap;
- - }
- -
- - @Override
- - public BitmapContainer clone() {
- - return create(bitmap.copy(bitmap.getConfig(), bitmap.isMutable()));
- - }
- -
- - @Override
- - public Bitmap getBitmap() {
- - // Not making a defensive copy for performance considerations. During image processing,
- - // users may need to set and get the bitmap many times.
- - return bitmap;
- - }
- -
- - @Override
- - public TensorBuffer getTensorBuffer(DataType dataType) {
- - TensorBuffer buffer = TensorBuffer.createDynamic(dataType);
- - ImageConversions.convertBitmapToTensorBuffer(bitmap, buffer);
- - return buffer;
- - }
- -
- - @Override
- - public Image getMediaImage() {
- - throw new UnsupportedOperationException(
- - "Converting from Bitmap to android.media.Image is unsupported.");
- - }
- -
- - @Override
- - public int getWidth() {
- - return bitmap.getWidth();
- - }
- -
- - @Override
- - public int getHeight() {
- - return bitmap.getHeight();
- - }
- -
- - @Override
- - public ColorSpaceType getColorSpaceType() {
- - return ColorSpaceType.fromBitmapConfig(bitmap.getConfig());
- - }
- + private final Bitmap bitmap;
- +
- + /**
- + * Creates a {@link BitmapContainer} object with ARGB_8888 {@link Bitmap}.
- + *
- + * @throws IllegalArgumentException if the bitmap configuration is not ARGB_8888
- + */
- + static BitmapContainer create(Bitmap bitmap) {
- + return new BitmapContainer(bitmap);
- + }
- +
- + private BitmapContainer(Bitmap bitmap) {
- + checkNotNull(bitmap, "Cannot load null bitmap.");
- + checkArgument(bitmap.getConfig().equals(Config.ARGB_8888),
- + "Only supports loading ARGB_8888 bitmaps.");
- + this.bitmap = bitmap;
- + }
- +
- + @Override
- + public BitmapContainer clone() {
- + return create(bitmap.copy(bitmap.getConfig(), bitmap.isMutable()));
- + }
- +
- + @Override
- + public Bitmap getBitmap() {
- + // Not making a defensive copy for performance considerations. During image processing,
- + // users may need to set and get the bitmap many times.
- + return bitmap;
- + }
- +
- + @Override
- + public TensorBuffer getTensorBuffer(DataType dataType) {
- + TensorBuffer buffer = TensorBuffer.createDynamic(dataType);
- + ImageConversions.convertBitmapToTensorBuffer(bitmap, buffer);
- + return buffer;
- + }
- +
- + @Override
- + public Image getMediaImage() {
- + throw new UnsupportedOperationException(
- + "Converting from Bitmap to android.media.Image is unsupported.");
- + }
- +
- + @Override
- + public int getWidth() {
- + return bitmap.getWidth();
- + }
- +
- + @Override
- + public int getHeight() {
- + return bitmap.getHeight();
- + }
- +
- + @Override
- + public ColorSpaceType getColorSpaceType() {
- + return ColorSpaceType.fromBitmapConfig(bitmap.getConfig());
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java
- index 8571d6227e136..a2e833b68d6d0 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java
- @@ -18,13 +18,15 @@ package org.tensorflow.lite.support.image;
- import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument;
-
- import android.graphics.RectF;
- +
- +import org.tensorflow.lite.DataType;
- +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- +
- import java.nio.ByteBuffer;
- import java.nio.FloatBuffer;
- import java.util.ArrayList;
- import java.util.Arrays;
- import java.util.List;
- -import org.tensorflow.lite.DataType;
- -import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- /**
- * Helper class for converting values that represents bounding boxes into rectangles.
- @@ -37,207 +39,186 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- * elements in each type is configurable as well.
- */
- public final class BoundingBoxUtil {
- + /** Denotes how a bounding box is represented. */
- + public enum Type {
- + /**
- + * Represents the bounding box by using the combination of boundaries, {left, top, right,
- + * bottom}. The default order is {left, top, right, bottom}. Other orders can be indicated
- + * by an index array.
- + */
- + BOUNDARIES,
- + /**
- + * Represents the bounding box by using the upper_left corner, width and height. The default
- + * order is {upper_left_x, upper_left_y, width, height}. Other orders can be indicated by an
- + * index array.
- + */
- + UPPER_LEFT,
- + /**
- + * Represents the bounding box by using the center of the box, width and height. The default
- + * order is {center_x, center_y, width, height}. Other orders can be indicated by an index
- + * array.
- + */
- + CENTER,
- + }
- +
- + /** Denotes if the coordinates are actual pixels or relative ratios. */
- + public enum CoordinateType {
- + /** The coordinates are relative ratios in range [0, 1]. */
- + RATIO,
- + /** The coordinates are actual pixel values. */
- + PIXEL
- + }
-
- - /** Denotes how a bounding box is represented. */
- - public enum Type {
- - /**
- - * Represents the bounding box by using the combination of boundaries, {left, top, right,
- - * bottom}. The default order is {left, top, right, bottom}. Other orders can be indicated by an
- - * index array.
- - */
- - BOUNDARIES,
- - /**
- - * Represents the bounding box by using the upper_left corner, width and height. The default
- - * order is {upper_left_x, upper_left_y, width, height}. Other orders can be indicated by an
- - * index array.
- - */
- - UPPER_LEFT,
- /**
- - * Represents the bounding box by using the center of the box, width and height. The default
- - * order is {center_x, center_y, width, height}. Other orders can be indicated by an index
- - * array.
- + * Creates a list of bounding boxes from a {@link TensorBuffer} which represents bounding boxes.
- + *
- + * @param tensor holds the data representing some boxes.
- + * @param valueIndex denotes the order of the elements defined in each bounding box type. An
- + * empty
- + * index array represent the default order of each bounding box type. For example, to denote
- + * the default order of BOUNDARIES, {left, top, right, bottom}, the index should be {0, 1,
- + * 2, 3}. To denote the order {left, right, top, bottom}, the order should be {0, 2, 1, 3}.
- + * <p>The index array can be applied to all bounding box types to adjust the order of their
- + * corresponding underlying elements.
- + * @param boundingBoxAxis specifies the index of the dimension that represents bounding box. The
- + * size of that dimension is required to be 4. Index here starts from 0. For example, if the
- + * tensor has shape 4x10, the axis for bounding boxes is likely to be 0. Negative axis is
- + * also supported: -1 gives the last axis and -2 gives the second, .etc. theFor shape 10x4, the
- + * axis is likely to be 1 (or -1, equivalently).
- + * @param type defines how values should be converted into boxes. See {@link Type}
- + * @param coordinateType defines how values are interpreted to coordinates. See {@link
- + * CoordinateType}
- + * @param height the height of the image which the boxes belong to. Only has effects when {@code
- + * coordinateType} is {@link CoordinateType#RATIO}
- + * @param width the width of the image which the boxes belong to. Only has effects when {@code
- + * coordinateType} is {@link CoordinateType#RATIO}
- + * @return A list of bounding boxes that the {@code tensor} represents. All dimensions except
- + * {@code boundingBoxAxis} will be collapsed with order kept. For example, given {@code
- + * tensor} with shape {1, 4, 10, 2} and {@code boundingBoxAxis = 1}, The result will be a
- + * list of 20 bounding boxes.
- + * @throws IllegalArgumentException if size of bounding box dimension (set by {@code
- + * boundingBoxAxis}) is not 4.
- + * @throws IllegalArgumentException if {@code boundingBoxAxis} is not in {@code (-(D+1), D)}
- + * where
- + * {@code D} is the number of dimensions of the {@code tensor}.
- + * @throws IllegalArgumentException if {@code tensor} has data type other than {@link
- + * DataType#FLOAT32}.
- */
- - CENTER,
- - }
- -
- - /** Denotes if the coordinates are actual pixels or relative ratios. */
- - public enum CoordinateType {
- - /** The coordinates are relative ratios in range [0, 1]. */
- - RATIO,
- - /** The coordinates are actual pixel values. */
- - PIXEL
- - }
- -
- - /**
- - * Creates a list of bounding boxes from a {@link TensorBuffer} which represents bounding boxes.
- - *
- - * @param tensor holds the data representing some boxes.
- - * @param valueIndex denotes the order of the elements defined in each bounding box type. An empty
- - * index array represent the default order of each bounding box type. For example, to denote
- - * the default order of BOUNDARIES, {left, top, right, bottom}, the index should be {0, 1, 2,
- - * 3}. To denote the order {left, right, top, bottom}, the order should be {0, 2, 1, 3}.
- - * <p>The index array can be applied to all bounding box types to adjust the order of their
- - * corresponding underlying elements.
- - * @param boundingBoxAxis specifies the index of the dimension that represents bounding box. The
- - * size of that dimension is required to be 4. Index here starts from 0. For example, if the
- - * tensor has shape 4x10, the axis for bounding boxes is likely to be 0. Negative axis is also
- - * supported: -1 gives the last axis and -2 gives the second, .etc. theFor shape 10x4, the
- - * axis is likely to be 1 (or -1, equivalently).
- - * @param type defines how values should be converted into boxes. See {@link Type}
- - * @param coordinateType defines how values are interpreted to coordinates. See {@link
- - * CoordinateType}
- - * @param height the height of the image which the boxes belong to. Only has effects when {@code
- - * coordinateType} is {@link CoordinateType#RATIO}
- - * @param width the width of the image which the boxes belong to. Only has effects when {@code
- - * coordinateType} is {@link CoordinateType#RATIO}
- - * @return A list of bounding boxes that the {@code tensor} represents. All dimensions except
- - * {@code boundingBoxAxis} will be collapsed with order kept. For example, given {@code
- - * tensor} with shape {1, 4, 10, 2} and {@code boundingBoxAxis = 1}, The result will be a list
- - * of 20 bounding boxes.
- - * @throws IllegalArgumentException if size of bounding box dimension (set by {@code
- - * boundingBoxAxis}) is not 4.
- - * @throws IllegalArgumentException if {@code boundingBoxAxis} is not in {@code (-(D+1), D)} where
- - * {@code D} is the number of dimensions of the {@code tensor}.
- - * @throws IllegalArgumentException if {@code tensor} has data type other than {@link
- - * DataType#FLOAT32}.
- - */
- - public static List<RectF> convert(
- - TensorBuffer tensor,
- - int[] valueIndex,
- - int boundingBoxAxis,
- - Type type,
- - CoordinateType coordinateType,
- - int height,
- - int width) {
- - int[] shape = tensor.getShape();
- - checkArgument(
- - boundingBoxAxis >= -shape.length && boundingBoxAxis < shape.length,
- - String.format(
- - "Axis %d is not in range (-(D+1), D), where D is the number of dimensions of input"
- - + " tensor (shape=%s)",
- - boundingBoxAxis, Arrays.toString(shape)));
- - if (boundingBoxAxis < 0) {
- - boundingBoxAxis = shape.length + boundingBoxAxis;
- - }
- - checkArgument(
- - shape[boundingBoxAxis] == 4,
- - String.format(
- - "Size of bounding box dimension %d is not 4. Got %d in shape %s",
- - boundingBoxAxis, shape[boundingBoxAxis], Arrays.toString(shape)));
- - checkArgument(
- - valueIndex.length == 4,
- - String.format(
- - "Bounding box index array length %d is not 4. Got index array %s",
- - valueIndex.length, Arrays.toString(valueIndex)));
- - checkArgument(
- - tensor.getDataType() == DataType.FLOAT32,
- - "Bounding Boxes only create from FLOAT32 buffers. Got: " + tensor.getDataType().name());
- - List<RectF> boundingBoxList = new ArrayList<>();
- - // Collapse dimensions to {a, 4, b}. So each bounding box could be represent as (i, j), and its
- - // four values are (i, k, j), where 0 <= k < 4. We can compute the 4 flattened index by
- - // i * 4b + k * b + j.
- - int a = 1;
- - for (int i = 0; i < boundingBoxAxis; i++) {
- - a *= shape[i];
- + public static List<RectF> convert(TensorBuffer tensor, int[] valueIndex, int boundingBoxAxis,
- + Type type, CoordinateType coordinateType, int height, int width) {
- + int[] shape = tensor.getShape();
- + checkArgument(boundingBoxAxis >= -shape.length && boundingBoxAxis < shape.length,
- + String.format(
- + "Axis %d is not in range (-(D+1), D), where D is the number of dimensions of input"
- + + " tensor (shape=%s)",
- + boundingBoxAxis, Arrays.toString(shape)));
- + if (boundingBoxAxis < 0) {
- + boundingBoxAxis = shape.length + boundingBoxAxis;
- + }
- + checkArgument(shape[boundingBoxAxis] == 4,
- + String.format("Size of bounding box dimension %d is not 4. Got %d in shape %s",
- + boundingBoxAxis, shape[boundingBoxAxis], Arrays.toString(shape)));
- + checkArgument(valueIndex.length == 4,
- + String.format("Bounding box index array length %d is not 4. Got index array %s",
- + valueIndex.length, Arrays.toString(valueIndex)));
- + checkArgument(tensor.getDataType() == DataType.FLOAT32,
- + "Bounding Boxes only create from FLOAT32 buffers. Got: "
- + + tensor.getDataType().name());
- + List<RectF> boundingBoxList = new ArrayList<>();
- + // Collapse dimensions to {a, 4, b}. So each bounding box could be represent as (i, j), and
- + // its four values are (i, k, j), where 0 <= k < 4. We can compute the 4 flattened index by
- + // i * 4b + k * b + j.
- + int a = 1;
- + for (int i = 0; i < boundingBoxAxis; i++) {
- + a *= shape[i];
- + }
- + int b = 1;
- + for (int i = boundingBoxAxis + 1; i < shape.length; i++) {
- + b *= shape[i];
- + }
- + float[] values = new float[4];
- + ByteBuffer byteBuffer = tensor.getBuffer();
- + byteBuffer.rewind();
- + FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
- + for (int i = 0; i < a; i++) {
- + for (int j = 0; j < b; j++) {
- + for (int k = 0; k < 4; k++) {
- + values[k] = floatBuffer.get((i * 4 + k) * b + j);
- + }
- + boundingBoxList.add(convertOneBoundingBox(
- + values, valueIndex, type, coordinateType, height, width));
- + }
- + }
- + byteBuffer.rewind();
- + return boundingBoxList;
- }
- - int b = 1;
- - for (int i = boundingBoxAxis + 1; i < shape.length; i++) {
- - b *= shape[i];
- +
- + private static RectF convertOneBoundingBox(float[] values, int[] valueIndex, Type type,
- + CoordinateType coordinateType, int height, int width) {
- + float[] orderedValues = new float[4];
- + for (int i = 0; i < 4; i++) {
- + orderedValues[i] = values[valueIndex[i]];
- + }
- + return convertOneBoundingBox(orderedValues, type, coordinateType, height, width);
- }
- - float[] values = new float[4];
- - ByteBuffer byteBuffer = tensor.getBuffer();
- - byteBuffer.rewind();
- - FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
- - for (int i = 0; i < a; i++) {
- - for (int j = 0; j < b; j++) {
- - for (int k = 0; k < 4; k++) {
- - values[k] = floatBuffer.get((i * 4 + k) * b + j);
- +
- + private static RectF convertOneBoundingBox(
- + float[] values, Type type, CoordinateType coordinateType, int height, int width) {
- + switch (type) {
- + case BOUNDARIES:
- + return convertFromBoundaries(values, coordinateType, height, width);
- + case UPPER_LEFT:
- + return convertFromUpperLeft(values, coordinateType, height, width);
- + case CENTER:
- + return convertFromCenter(values, coordinateType, height, width);
- }
- - boundingBoxList.add(
- - convertOneBoundingBox(values, valueIndex, type, coordinateType, height, width));
- - }
- + throw new IllegalArgumentException("Cannot recognize BoundingBox.Type " + type);
- }
- - byteBuffer.rewind();
- - return boundingBoxList;
- - }
- -
- - private static RectF convertOneBoundingBox(
- - float[] values,
- - int[] valueIndex,
- - Type type,
- - CoordinateType coordinateType,
- - int height,
- - int width) {
- - float[] orderedValues = new float[4];
- - for (int i = 0; i < 4; i++) {
- - orderedValues[i] = values[valueIndex[i]];
- +
- + private static RectF convertFromBoundaries(
- + float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
- + float left = values[0];
- + float top = values[1];
- + float right = values[2];
- + float bottom = values[3];
- + return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
- + }
- +
- + private static RectF convertFromUpperLeft(
- + float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
- + float left = values[0];
- + float top = values[1];
- + float right = values[0] + values[2];
- + float bottom = values[1] + values[3];
- + return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
- }
- - return convertOneBoundingBox(orderedValues, type, coordinateType, height, width);
- - }
- -
- - private static RectF convertOneBoundingBox(
- - float[] values, Type type, CoordinateType coordinateType, int height, int width) {
- - switch (type) {
- - case BOUNDARIES:
- - return convertFromBoundaries(values, coordinateType, height, width);
- - case UPPER_LEFT:
- - return convertFromUpperLeft(values, coordinateType, height, width);
- - case CENTER:
- - return convertFromCenter(values, coordinateType, height, width);
- +
- + private static RectF convertFromCenter(
- + float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
- + float centerX = values[0];
- + float centerY = values[1];
- + float w = values[2];
- + float h = values[3];
- +
- + float left = centerX - w / 2;
- + float top = centerY - h / 2;
- + float right = centerX + w / 2;
- + float bottom = centerY + h / 2;
- + return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
- }
- - throw new IllegalArgumentException("Cannot recognize BoundingBox.Type " + type);
- - }
- -
- - private static RectF convertFromBoundaries(
- - float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
- - float left = values[0];
- - float top = values[1];
- - float right = values[2];
- - float bottom = values[3];
- - return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
- - }
- -
- - private static RectF convertFromUpperLeft(
- - float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
- - float left = values[0];
- - float top = values[1];
- - float right = values[0] + values[2];
- - float bottom = values[1] + values[3];
- - return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
- - }
- -
- - private static RectF convertFromCenter(
- - float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
- - float centerX = values[0];
- - float centerY = values[1];
- - float w = values[2];
- - float h = values[3];
- -
- - float left = centerX - w / 2;
- - float top = centerY - h / 2;
- - float right = centerX + w / 2;
- - float bottom = centerY + h / 2;
- - return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
- - }
- -
- - private static RectF getRectF(
- - float left,
- - float top,
- - float right,
- - float bottom,
- - int imageHeight,
- - int imageWidth,
- - CoordinateType coordinateType) {
- - if (coordinateType == CoordinateType.PIXEL) {
- - return new RectF(left, top, right, bottom);
- - } else if (coordinateType == CoordinateType.RATIO) {
- - return new RectF(
- - left * imageWidth, top * imageHeight, right * imageWidth, bottom * imageHeight);
- - } else {
- - throw new IllegalArgumentException("Cannot convert coordinate type " + coordinateType);
- +
- + private static RectF getRectF(float left, float top, float right, float bottom, int imageHeight,
- + int imageWidth, CoordinateType coordinateType) {
- + if (coordinateType == CoordinateType.PIXEL) {
- + return new RectF(left, top, right, bottom);
- + } else if (coordinateType == CoordinateType.RATIO) {
- + return new RectF(
- + left * imageWidth, top * imageHeight, right * imageWidth, bottom * imageHeight);
- + } else {
- + throw new IllegalArgumentException("Cannot convert coordinate type " + coordinateType);
- + }
- }
- - }
-
- - // Private constructor to prevent initialization.
- - private BoundingBoxUtil() {}
- + // Private constructor to prevent initialization.
- + private BoundingBoxUtil() {}
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java
- index 457bcf1da1de3..716cacdf7bf51 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java
- @@ -20,354 +20,351 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
- import android.graphics.Bitmap;
- import android.graphics.Bitmap.Config;
- import android.graphics.ImageFormat;
- -import java.util.Arrays;
- +
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- +import java.util.Arrays;
- +
- /** Represents the type of color space of an image. */
- public enum ColorSpaceType {
- - /** Each pixel has red, green, and blue color components. */
- - RGB(0) {
- -
- - // The channel axis should always be 3 for RGB images.
- - private static final int CHANNEL_VALUE = 3;
- -
- - @Override
- - Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
- - return ImageConversions.convertRgbTensorBufferToBitmap(buffer);
- + /** Each pixel has red, green, and blue color components. */
- + RGB(0) {
- + // The channel axis should always be 3 for RGB images.
- + private static final int CHANNEL_VALUE = 3;
- +
- + @Override
- + Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
- + return ImageConversions.convertRgbTensorBufferToBitmap(buffer);
- + }
- +
- + @Override
- + int getChannelValue() {
- + return CHANNEL_VALUE;
- + }
- +
- + @Override
- + int[] getNormalizedShape(int[] shape) {
- + switch (shape.length) {
- + // The shape is in (h, w, c) format.
- + case 3:
- + return insertValue(shape, BATCH_DIM, BATCH_VALUE);
- + case 4:
- + return shape;
- + default:
- + throw new IllegalArgumentException(getShapeInfoMessage()
- + + "The provided image shape is " + Arrays.toString(shape));
- + }
- + }
- +
- + @Override
- + int getNumElements(int height, int width) {
- + return height * width * CHANNEL_VALUE;
- + }
- +
- + @Override
- + String getShapeInfoMessage() {
- + return "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
- + + " representing R, G, B in order. ";
- + }
- +
- + @Override
- + Config toBitmapConfig() {
- + return Config.ARGB_8888;
- + }
- + },
- +
- + /** Each pixel is a single element representing only the amount of light. */
- + GRAYSCALE(1) {
- + // The channel axis should always be 1 for grayscale images.
- + private static final int CHANNEL_VALUE = 1;
- +
- + @Override
- + Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
- + return ImageConversions.convertGrayscaleTensorBufferToBitmap(buffer);
- + }
- +
- + @Override
- + int getChannelValue() {
- + return CHANNEL_VALUE;
- + }
- +
- + @Override
- + int[] getNormalizedShape(int[] shape) {
- + switch (shape.length) {
- + // The shape is in (h, w) format.
- + case 2:
- + int[] shapeWithBatch = insertValue(shape, BATCH_DIM, BATCH_VALUE);
- + return insertValue(shapeWithBatch, CHANNEL_DIM, CHANNEL_VALUE);
- + case 4:
- + return shape;
- + default:
- + // (1, h, w) and (h, w, 1) are potential grayscale image shapes. However, since
- + // they both have three dimensions, it will require extra info to differentiate
- + // between them. Since we haven't encountered real use cases of these two
- + // shapes, they are not supported at this moment to avoid confusion. We may want
- + // to revisit it in the future.
- + throw new IllegalArgumentException(getShapeInfoMessage()
- + + "The provided image shape is " + Arrays.toString(shape));
- + }
- + }
- +
- + @Override
- + int getNumElements(int height, int width) {
- + return height * width;
- + }
- +
- + @Override
- + String getShapeInfoMessage() {
- + return "The shape of a grayscale image should be (h, w) or (1, h, w, 1). ";
- + }
- +
- + @Override
- + Config toBitmapConfig() {
- + return Config.ALPHA_8;
- + }
- + },
- +
- + /** YUV420sp format, encoded as "YYYYYYYY UVUV". */
- + NV12(2) {
- + @Override
- + int getNumElements(int height, int width) {
- + return getYuv420NumElements(height, width);
- + }
- + },
- +
- + /**
- + * YUV420sp format, encoded as "YYYYYYYY VUVU", the standard picture format on Android Camera1
- + * preview.
- + */
- + NV21(3) {
- + @Override
- + int getNumElements(int height, int width) {
- + return getYuv420NumElements(height, width);
- + }
- + },
- +
- + /** YUV420p format, encoded as "YYYYYYYY VV UU". */
- + YV12(4) {
- + @Override
- + int getNumElements(int height, int width) {
- + return getYuv420NumElements(height, width);
- + }
- + },
- +
- + /** YUV420p format, encoded as "YYYYYYYY UU VV". */
- + YV21(5) {
- + @Override
- + int getNumElements(int height, int width) {
- + return getYuv420NumElements(height, width);
- + }
- + },
- +
- + /**
- + * YUV420 format corresponding to {@link android.graphics.ImageFormat#YUV_420_888}. The actual
- + * encoding format (i.e. NV12 / Nv21 / YV12 / YV21) depends on the implementation of the image.
- + *
- + * <p>Use this format only when you load an {@link android.media.Image}.
- + */
- + YUV_420_888(6) {
- + @Override
- + int getNumElements(int height, int width) {
- + return getYuv420NumElements(height, width);
- + }
- + };
- +
- + private static final int BATCH_DIM = 0; // The first element of the normalizaed shape.
- + private static final int BATCH_VALUE = 1; // The batch axis should always be one.
- + private static final int HEIGHT_DIM = 1; // The second element of the normalizaed shape.
- + private static final int WIDTH_DIM = 2; // The third element of the normalizaed shape.
- + private static final int CHANNEL_DIM = 3; // The fourth element of the normalizaed shape.
- + private final int value;
- +
- + ColorSpaceType(int value) {
- + this.value = value;
- }
-
- - @Override
- - int getChannelValue() {
- - return CHANNEL_VALUE;
- + /**
- + * Converts a bitmap configuration into the corresponding color space type.
- + *
- + * @throws IllegalArgumentException if the config is unsupported
- + */
- + static ColorSpaceType fromBitmapConfig(Config config) {
- + switch (config) {
- + case ARGB_8888:
- + return ColorSpaceType.RGB;
- + case ALPHA_8:
- + return ColorSpaceType.GRAYSCALE;
- + default:
- + throw new IllegalArgumentException(
- + "Bitmap configuration: " + config + ", is not supported yet.");
- + }
- }
-
- - @Override
- - int[] getNormalizedShape(int[] shape) {
- - switch (shape.length) {
- - // The shape is in (h, w, c) format.
- - case 3:
- - return insertValue(shape, BATCH_DIM, BATCH_VALUE);
- - case 4:
- - return shape;
- - default:
- - throw new IllegalArgumentException(
- - getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape));
- - }
- + /**
- + * Converts an {@link ImageFormat} value into the corresponding color space type.
- + *
- + * @throws IllegalArgumentException if the config is unsupported
- + */
- + static ColorSpaceType fromImageFormat(int imageFormat) {
- + switch (imageFormat) {
- + case ImageFormat.NV21:
- + return ColorSpaceType.NV21;
- + case ImageFormat.YV12:
- + return ColorSpaceType.YV12;
- + case ImageFormat.YUV_420_888:
- + return ColorSpaceType.YUV_420_888;
- + default:
- + throw new IllegalArgumentException(
- + "ImageFormat: " + imageFormat + ", is not supported yet.");
- + }
- }
-
- - @Override
- - int getNumElements(int height, int width) {
- - return height * width * CHANNEL_VALUE;
- + public int getValue() {
- + return value;
- }
-
- - @Override
- - String getShapeInfoMessage() {
- - return "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
- - + " representing R, G, B in order. ";
- + /**
- + * Verifies if the given shape matches the color space type.
- + *
- + * @throws IllegalArgumentException if {@code shape} does not match the color space type
- + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- + */
- + void assertShape(int[] shape) {
- + assertRgbOrGrayScale("assertShape()");
- +
- + int[] normalizedShape = getNormalizedShape(shape);
- + checkArgument(isValidNormalizedShape(normalizedShape),
- + getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape));
- }
-
- - @Override
- - Config toBitmapConfig() {
- - return Config.ARGB_8888;
- + /**
- + * Verifies if the given {@code numElements} in an image buffer matches {@code height} / {@code
- + * width} under this color space type. For example, the {@code numElements} of an RGB image of
- + * 30 x 20 should be {@code 30 * 20 * 3 = 1800}; the {@code numElements} of a NV21 image of 30 x
- + * 20 should be {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}.
- + *
- + * @throws IllegalArgumentException if {@code shape} does not match the color space type
- + */
- + void assertNumElements(int numElements, int height, int width) {
- + checkArgument(numElements >= getNumElements(height, width),
- + String.format(
- + "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
- + + " expected number of elements should be at least %d.",
- + numElements, this.name(), height, width, getNumElements(height, width)));
- }
- - },
- -
- - /** Each pixel is a single element representing only the amount of light. */
- - GRAYSCALE(1) {
- -
- - // The channel axis should always be 1 for grayscale images.
- - private static final int CHANNEL_VALUE = 1;
-
- - @Override
- + /**
- + * Converts a {@link TensorBuffer} that represents an image to a Bitmap with the color space
- + * type.
- + *
- + * @throws IllegalArgumentException if the shape of buffer does not match the color space type,
- + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- + */
- Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
- - return ImageConversions.convertGrayscaleTensorBufferToBitmap(buffer);
- + throw new UnsupportedOperationException(
- + "convertTensorBufferToBitmap() is unsupported for the color space type "
- + + this.name());
- }
-
- - @Override
- - int getChannelValue() {
- - return CHANNEL_VALUE;
- + /**
- + * Returns the width of the given shape corresponding to the color space type.
- + *
- + * @throws IllegalArgumentException if {@code shape} does not match the color space type
- + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- + */
- + int getWidth(int[] shape) {
- + assertRgbOrGrayScale("getWidth()");
- + assertShape(shape);
- + return getNormalizedShape(shape)[WIDTH_DIM];
- }
-
- - @Override
- - int[] getNormalizedShape(int[] shape) {
- - switch (shape.length) {
- - // The shape is in (h, w) format.
- - case 2:
- - int[] shapeWithBatch = insertValue(shape, BATCH_DIM, BATCH_VALUE);
- - return insertValue(shapeWithBatch, CHANNEL_DIM, CHANNEL_VALUE);
- - case 4:
- - return shape;
- - default:
- - // (1, h, w) and (h, w, 1) are potential grayscale image shapes. However, since they
- - // both have three dimensions, it will require extra info to differentiate between them.
- - // Since we haven't encountered real use cases of these two shapes, they are not supported
- - // at this moment to avoid confusion. We may want to revisit it in the future.
- - throw new IllegalArgumentException(
- - getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape));
- - }
- + /**
- + * Returns the height of the given shape corresponding to the color space type.
- + *
- + * @throws IllegalArgumentException if {@code shape} does not match the color space type
- + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- + */
- + int getHeight(int[] shape) {
- + assertRgbOrGrayScale("getHeight()");
- + assertShape(shape);
- + return getNormalizedShape(shape)[HEIGHT_DIM];
- }
-
- - @Override
- - int getNumElements(int height, int width) {
- - return height * width;
- + /**
- + * Returns the channel value corresponding to the color space type.
- + *
- + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- + */
- + int getChannelValue() {
- + throw new UnsupportedOperationException(
- + "getChannelValue() is unsupported for the color space type " + this.name());
- + }
- + /**
- + * Gets the normalized shape in the form of (1, h, w, c). Sometimes, a given shape may not have
- + * batch or channel axis.
- + *
- + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- + */
- + int[] getNormalizedShape(int[] shape) {
- + throw new UnsupportedOperationException(
- + "getNormalizedShape() is unsupported for the color space type " + this.name());
- }
-
- - @Override
- + /**
- + * Returns the shape information corresponding to the color space type.
- + *
- + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- + */
- String getShapeInfoMessage() {
- - return "The shape of a grayscale image should be (h, w) or (1, h, w, 1). ";
- + throw new UnsupportedOperationException(
- + "getShapeInfoMessage() is unsupported for the color space type " + this.name());
- }
-
- - @Override
- + /**
- + * Converts the color space type to the corresponding bitmap config.
- + *
- + * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- + */
- Config toBitmapConfig() {
- - return Config.ALPHA_8;
- + throw new UnsupportedOperationException(
- + "toBitmapConfig() is unsupported for the color space type " + this.name());
- }
- - },
-
- - /** YUV420sp format, encoded as "YYYYYYYY UVUV". */
- - NV12(2) {
- - @Override
- - int getNumElements(int height, int width) {
- - return getYuv420NumElements(height, width);
- - }
- - },
- -
- - /**
- - * YUV420sp format, encoded as "YYYYYYYY VUVU", the standard picture format on Android Camera1
- - * preview.
- - */
- - NV21(3) {
- - @Override
- - int getNumElements(int height, int width) {
- - return getYuv420NumElements(height, width);
- - }
- - },
- + /**
- + * Gets the number of elements given the height and width of an image. For example, the number
- + * of elements of an RGB image of 30 x 20 is {@code 30 * 20 * 3 = 1800}; the number of elements
- + * of a NV21 image of 30 x 20 is {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}.
- + */
- + abstract int getNumElements(int height, int width);
-
- - /** YUV420p format, encoded as "YYYYYYYY VV UU". */
- - YV12(4) {
- - @Override
- - int getNumElements(int height, int width) {
- - return getYuv420NumElements(height, width);
- + private static int getYuv420NumElements(int height, int width) {
- + // Height and width of U/V planes are half of the Y plane.
- + return height * width + ((height + 1) / 2) * ((width + 1) / 2) * 2;
- }
- - },
-
- - /** YUV420p format, encoded as "YYYYYYYY UU VV". */
- - YV21(5) {
- - @Override
- - int getNumElements(int height, int width) {
- - return getYuv420NumElements(height, width);
- + /** Inserts a value at the specified position and return the new array. */
- + private static int[] insertValue(int[] array, int pos, int value) {
- + int[] newArray = new int[array.length + 1];
- + for (int i = 0; i < pos; i++) {
- + newArray[i] = array[i];
- + }
- + newArray[pos] = value;
- + for (int i = pos + 1; i < newArray.length; i++) {
- + newArray[i] = array[i - 1];
- + }
- + return newArray;
- }
- - },
- -
- - /**
- - * YUV420 format corresponding to {@link android.graphics.ImageFormat#YUV_420_888}. The actual
- - * encoding format (i.e. NV12 / Nv21 / YV12 / YV21) depends on the implementation of the image.
- - *
- - * <p>Use this format only when you load an {@link android.media.Image}.
- - */
- - YUV_420_888(6) {
- - @Override
- - int getNumElements(int height, int width) {
- - return getYuv420NumElements(height, width);
- - }
- - };
- -
- - private static final int BATCH_DIM = 0; // The first element of the normalizaed shape.
- - private static final int BATCH_VALUE = 1; // The batch axis should always be one.
- - private static final int HEIGHT_DIM = 1; // The second element of the normalizaed shape.
- - private static final int WIDTH_DIM = 2; // The third element of the normalizaed shape.
- - private static final int CHANNEL_DIM = 3; // The fourth element of the normalizaed shape.
- - private final int value;
- -
- - ColorSpaceType(int value) {
- - this.value = value;
- - }
- -
- - /**
- - * Converts a bitmap configuration into the corresponding color space type.
- - *
- - * @throws IllegalArgumentException if the config is unsupported
- - */
- - static ColorSpaceType fromBitmapConfig(Config config) {
- - switch (config) {
- - case ARGB_8888:
- - return ColorSpaceType.RGB;
- - case ALPHA_8:
- - return ColorSpaceType.GRAYSCALE;
- - default:
- - throw new IllegalArgumentException(
- - "Bitmap configuration: " + config + ", is not supported yet.");
- - }
- - }
- -
- - /**
- - * Converts an {@link ImageFormat} value into the corresponding color space type.
- - *
- - * @throws IllegalArgumentException if the config is unsupported
- - */
- - static ColorSpaceType fromImageFormat(int imageFormat) {
- - switch (imageFormat) {
- - case ImageFormat.NV21:
- - return ColorSpaceType.NV21;
- - case ImageFormat.YV12:
- - return ColorSpaceType.YV12;
- - case ImageFormat.YUV_420_888:
- - return ColorSpaceType.YUV_420_888;
- - default:
- - throw new IllegalArgumentException(
- - "ImageFormat: " + imageFormat + ", is not supported yet.");
- - }
- - }
- -
- - public int getValue() {
- - return value;
- - }
- -
- - /**
- - * Verifies if the given shape matches the color space type.
- - *
- - * @throws IllegalArgumentException if {@code shape} does not match the color space type
- - * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- - */
- - void assertShape(int[] shape) {
- - assertRgbOrGrayScale("assertShape()");
- -
- - int[] normalizedShape = getNormalizedShape(shape);
- - checkArgument(
- - isValidNormalizedShape(normalizedShape),
- - getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape));
- - }
- -
- - /**
- - * Verifies if the given {@code numElements} in an image buffer matches {@code height} / {@code
- - * width} under this color space type. For example, the {@code numElements} of an RGB image of 30
- - * x 20 should be {@code 30 * 20 * 3 = 1800}; the {@code numElements} of a NV21 image of 30 x 20
- - * should be {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}.
- - *
- - * @throws IllegalArgumentException if {@code shape} does not match the color space type
- - */
- - void assertNumElements(int numElements, int height, int width) {
- - checkArgument(
- - numElements >= getNumElements(height, width),
- - String.format(
- - "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
- - + " expected number of elements should be at least %d.",
- - numElements, this.name(), height, width, getNumElements(height, width)));
- - }
- -
- - /**
- - * Converts a {@link TensorBuffer} that represents an image to a Bitmap with the color space type.
- - *
- - * @throws IllegalArgumentException if the shape of buffer does not match the color space type,
- - * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- - */
- - Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
- - throw new UnsupportedOperationException(
- - "convertTensorBufferToBitmap() is unsupported for the color space type " + this.name());
- - }
- -
- - /**
- - * Returns the width of the given shape corresponding to the color space type.
- - *
- - * @throws IllegalArgumentException if {@code shape} does not match the color space type
- - * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- - */
- - int getWidth(int[] shape) {
- - assertRgbOrGrayScale("getWidth()");
- - assertShape(shape);
- - return getNormalizedShape(shape)[WIDTH_DIM];
- - }
- -
- - /**
- - * Returns the height of the given shape corresponding to the color space type.
- - *
- - * @throws IllegalArgumentException if {@code shape} does not match the color space type
- - * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- - */
- - int getHeight(int[] shape) {
- - assertRgbOrGrayScale("getHeight()");
- - assertShape(shape);
- - return getNormalizedShape(shape)[HEIGHT_DIM];
- - }
- -
- - /**
- - * Returns the channel value corresponding to the color space type.
- - *
- - * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- - */
- - int getChannelValue() {
- - throw new UnsupportedOperationException(
- - "getChannelValue() is unsupported for the color space type " + this.name());
- - }
- - /**
- - * Gets the normalized shape in the form of (1, h, w, c). Sometimes, a given shape may not have
- - * batch or channel axis.
- - *
- - * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- - */
- - int[] getNormalizedShape(int[] shape) {
- - throw new UnsupportedOperationException(
- - "getNormalizedShape() is unsupported for the color space type " + this.name());
- - }
- -
- - /**
- - * Returns the shape information corresponding to the color space type.
- - *
- - * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- - */
- - String getShapeInfoMessage() {
- - throw new UnsupportedOperationException(
- - "getShapeInfoMessage() is unsupported for the color space type " + this.name());
- - }
- -
- - /**
- - * Converts the color space type to the corresponding bitmap config.
- - *
- - * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- - */
- - Config toBitmapConfig() {
- - throw new UnsupportedOperationException(
- - "toBitmapConfig() is unsupported for the color space type " + this.name());
- - }
- -
- - /**
- - * Gets the number of elements given the height and width of an image. For example, the number of
- - * elements of an RGB image of 30 x 20 is {@code 30 * 20 * 3 = 1800}; the number of elements of a
- - * NV21 image of 30 x 20 is {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}.
- - */
- - abstract int getNumElements(int height, int width);
- -
- - private static int getYuv420NumElements(int height, int width) {
- - // Height and width of U/V planes are half of the Y plane.
- - return height * width + ((height + 1) / 2) * ((width + 1) / 2) * 2;
- - }
- -
- - /** Inserts a value at the specified position and return the new array. */
- - private static int[] insertValue(int[] array, int pos, int value) {
- - int[] newArray = new int[array.length + 1];
- - for (int i = 0; i < pos; i++) {
- - newArray[i] = array[i];
- - }
- - newArray[pos] = value;
- - for (int i = pos + 1; i < newArray.length; i++) {
- - newArray[i] = array[i - 1];
- +
- + protected boolean isValidNormalizedShape(int[] shape) {
- + return shape[BATCH_DIM] == BATCH_VALUE && shape[HEIGHT_DIM] > 0 && shape[WIDTH_DIM] > 0
- + && shape[CHANNEL_DIM] == getChannelValue();
- }
- - return newArray;
- - }
- -
- - protected boolean isValidNormalizedShape(int[] shape) {
- - return shape[BATCH_DIM] == BATCH_VALUE
- - && shape[HEIGHT_DIM] > 0
- - && shape[WIDTH_DIM] > 0
- - && shape[CHANNEL_DIM] == getChannelValue();
- - }
- -
- - /** Some existing methods are only valid for RGB and GRAYSCALE images. */
- - private void assertRgbOrGrayScale(String unsupportedMethodName) {
- - if (this != ColorSpaceType.RGB && this != ColorSpaceType.GRAYSCALE) {
- - throw new UnsupportedOperationException(
- - unsupportedMethodName
- - + " only supports RGB and GRAYSCALE formats, but not "
- - + this.name());
- +
- + /** Some existing methods are only valid for RGB and GRAYSCALE images. */
- + private void assertRgbOrGrayScale(String unsupportedMethodName) {
- + if (this != ColorSpaceType.RGB && this != ColorSpaceType.GRAYSCALE) {
- + throw new UnsupportedOperationException(unsupportedMethodName
- + + " only supports RGB and GRAYSCALE formats, but not " + this.name());
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java
- index 379d14798d62d..5c097da5ecb6d 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java
- @@ -17,6 +17,7 @@ package org.tensorflow.lite.support.image;
-
- import android.graphics.Bitmap;
- import android.media.Image;
- +
- import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- @@ -32,28 +33,27 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- * </ul>
- */
- interface ImageContainer {
- + /** Performs deep copy of the {@link ImageContainer}. */
- + ImageContainer clone();
-
- - /** Performs deep copy of the {@link ImageContainer}. */
- - ImageContainer clone();
- -
- - /** Returns the width of the image. */
- - int getWidth();
- + /** Returns the width of the image. */
- + int getWidth();
-
- - /** Returns the height of the image. */
- - int getHeight();
- + /** Returns the height of the image. */
- + int getHeight();
-
- - /** Gets the {@link Bitmap} representation of the underlying image format. */
- - Bitmap getBitmap();
- + /** Gets the {@link Bitmap} representation of the underlying image format. */
- + Bitmap getBitmap();
-
- - /**
- - * Gets the {@link TensorBuffer} representation with the specific {@code dataType} of the
- - * underlying image format.
- - */
- - TensorBuffer getTensorBuffer(DataType dataType);
- + /**
- + * Gets the {@link TensorBuffer} representation with the specific {@code dataType} of the
- + * underlying image format.
- + */
- + TensorBuffer getTensorBuffer(DataType dataType);
-
- - /** Gets the {@link Image} representation of the underlying image format. */
- - Image getMediaImage();
- + /** Gets the {@link Image} representation of the underlying image format. */
- + Image getMediaImage();
-
- - /** Returns the color space type of the image. */
- - ColorSpaceType getColorSpaceType();
- + /** Returns the color space type of the image. */
- + ColorSpaceType getColorSpaceType();
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java
- index 8ed169c49348e..7ed5306fd9f96 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java
- @@ -17,128 +17,127 @@ package org.tensorflow.lite.support.image;
-
- import android.graphics.Bitmap;
- import android.graphics.Color;
- -import java.nio.ByteBuffer;
- -import java.nio.ByteOrder;
- +
- import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- +import java.nio.ByteBuffer;
- +import java.nio.ByteOrder;
- +
- /**
- * Implements some stateless image conversion methods.
- *
- * <p>This class is an internal helper for {@link org.tensorflow.lite.support.image}.
- */
- class ImageConversions {
- + /**
- + * Converts a {@link TensorBuffer} that represents a RGB image to an ARGB_8888 Bitmap.
- + *
- + * <p>Data in buffer will be converted into integer to match the Bitmap API.
- + *
- + * @param buffer a RGB image. Its shape should be either (h, w, 3) or (1, h, w, 3)
- + * @throws IllegalArgumentException if the shape of buffer is neither (h, w, 3) nor (1, h, w, 3)
- + */
- + static Bitmap convertRgbTensorBufferToBitmap(TensorBuffer buffer) {
- + int[] shape = buffer.getShape();
- + ColorSpaceType rgb = ColorSpaceType.RGB;
- + rgb.assertShape(shape);
-
- - /**
- - * Converts a {@link TensorBuffer} that represents a RGB image to an ARGB_8888 Bitmap.
- - *
- - * <p>Data in buffer will be converted into integer to match the Bitmap API.
- - *
- - * @param buffer a RGB image. Its shape should be either (h, w, 3) or (1, h, w, 3)
- - * @throws IllegalArgumentException if the shape of buffer is neither (h, w, 3) nor (1, h, w, 3)
- - */
- - static Bitmap convertRgbTensorBufferToBitmap(TensorBuffer buffer) {
- - int[] shape = buffer.getShape();
- - ColorSpaceType rgb = ColorSpaceType.RGB;
- - rgb.assertShape(shape);
- -
- - int h = rgb.getHeight(shape);
- - int w = rgb.getWidth(shape);
- - Bitmap bitmap = Bitmap.createBitmap(w, h, rgb.toBitmapConfig());
- -
- - // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
- - int[] intValues = new int[w * h];
- - int[] rgbValues = buffer.getIntArray();
- - for (int i = 0, j = 0; i < intValues.length; i++) {
- - int r = rgbValues[j++];
- - int g = rgbValues[j++];
- - int b = rgbValues[j++];
- - intValues[i] = Color.rgb(r, g, b);
- - }
- - bitmap.setPixels(intValues, 0, w, 0, 0, w, h);
- -
- - return bitmap;
- - }
- -
- - /**
- - * Converts a {@link TensorBuffer} that represents a grayscale image to an ALPHA_8 Bitmap.
- - *
- - * <p>Data in buffer will be converted into integer to match the Bitmap API.
- - *
- - * @param buffer a grayscale image. Its shape should be either (h, w) or (1, h, w)
- - * @throws IllegalArgumentException if the shape of buffer is neither (h, w) nor (1, h, w, 1)
- - */
- - static Bitmap convertGrayscaleTensorBufferToBitmap(TensorBuffer buffer) {
- - // Convert buffer into Uint8 as needed.
- - TensorBuffer uint8Buffer =
- - buffer.getDataType() == DataType.UINT8
- - ? buffer
- - : TensorBuffer.createFrom(buffer, DataType.UINT8);
- -
- - int[] shape = uint8Buffer.getShape();
- - ColorSpaceType grayscale = ColorSpaceType.GRAYSCALE;
- - grayscale.assertShape(shape);
- -
- - // Even though `Bitmap.createBitmap(int[] colors, int width, int height, Bitmap.Config config)`
- - // seems to work for internal Android testing framework, but it actually doesn't work for the
- - // real Android environment.
- - //
- - // The only reliable way to create an ALPHA_8 Bitmap is to use `copyPixelsFromBuffer()` to load
- - // the pixels from a ByteBuffer, and then use `copyPixelsToBuffer` to read out.
- - // Note: for ALPHA_8 Bitmap, methods such as, `setPixels()` and `getPixels()` do not work.
- - Bitmap bitmap =
- - Bitmap.createBitmap(
- - grayscale.getWidth(shape), grayscale.getHeight(shape), grayscale.toBitmapConfig());
- - uint8Buffer.getBuffer().rewind();
- - bitmap.copyPixelsFromBuffer(uint8Buffer.getBuffer());
- - return bitmap;
- - }
- -
- - /**
- - * Converts an Image in a Bitmap to a TensorBuffer (3D Tensor: Width-Height-Channel) whose memory
- - * is already allocated, or could be dynamically allocated.
- - *
- - * @param bitmap The Bitmap object representing the image. Currently we only support ARGB_8888
- - * config.
- - * @param buffer The destination of the conversion. Needs to be created in advance. If it's
- - * fixed-size, its flat size should be w*h*3.
- - * @throws IllegalArgumentException if the buffer is fixed-size, but the size doesn't match.
- - */
- - static void convertBitmapToTensorBuffer(Bitmap bitmap, TensorBuffer buffer) {
- - int w = bitmap.getWidth();
- - int h = bitmap.getHeight();
- - int[] intValues = new int[w * h];
- - bitmap.getPixels(intValues, 0, w, 0, 0, w, h);
- - // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
- - int[] shape = new int[] {h, w, 3};
- - switch (buffer.getDataType()) {
- - case UINT8:
- - byte[] byteArr = new byte[w * h * 3];
- + int h = rgb.getHeight(shape);
- + int w = rgb.getWidth(shape);
- + Bitmap bitmap = Bitmap.createBitmap(w, h, rgb.toBitmapConfig());
- +
- + // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
- + int[] intValues = new int[w * h];
- + int[] rgbValues = buffer.getIntArray();
- for (int i = 0, j = 0; i < intValues.length; i++) {
- - byteArr[j++] = (byte) ((intValues[i] >> 16) & 0xff);
- - byteArr[j++] = (byte) ((intValues[i] >> 8) & 0xff);
- - byteArr[j++] = (byte) (intValues[i] & 0xff);
- + int r = rgbValues[j++];
- + int g = rgbValues[j++];
- + int b = rgbValues[j++];
- + intValues[i] = Color.rgb(r, g, b);
- }
- - ByteBuffer byteBuffer = ByteBuffer.wrap(byteArr);
- - byteBuffer.order(ByteOrder.nativeOrder());
- - buffer.loadBuffer(byteBuffer, shape);
- - break;
- - case FLOAT32:
- - float[] floatArr = new float[w * h * 3];
- - for (int i = 0, j = 0; i < intValues.length; i++) {
- - floatArr[j++] = (float) ((intValues[i] >> 16) & 0xff);
- - floatArr[j++] = (float) ((intValues[i] >> 8) & 0xff);
- - floatArr[j++] = (float) (intValues[i] & 0xff);
- + bitmap.setPixels(intValues, 0, w, 0, 0, w, h);
- +
- + return bitmap;
- + }
- +
- + /**
- + * Converts a {@link TensorBuffer} that represents a grayscale image to an ALPHA_8 Bitmap.
- + *
- + * <p>Data in buffer will be converted into integer to match the Bitmap API.
- + *
- + * @param buffer a grayscale image. Its shape should be either (h, w) or (1, h, w)
- + * @throws IllegalArgumentException if the shape of buffer is neither (h, w) nor (1, h, w, 1)
- + */
- + static Bitmap convertGrayscaleTensorBufferToBitmap(TensorBuffer buffer) {
- + // Convert buffer into Uint8 as needed.
- + TensorBuffer uint8Buffer = buffer.getDataType() == DataType.UINT8
- + ? buffer
- + : TensorBuffer.createFrom(buffer, DataType.UINT8);
- +
- + int[] shape = uint8Buffer.getShape();
- + ColorSpaceType grayscale = ColorSpaceType.GRAYSCALE;
- + grayscale.assertShape(shape);
- +
- + // Even though `Bitmap.createBitmap(int[] colors, int width, int height, Bitmap.Config
- + // config)` seems to work for internal Android testing framework, but it actually doesn't
- + // work for the real Android environment.
- + //
- + // The only reliable way to create an ALPHA_8 Bitmap is to use `copyPixelsFromBuffer()` to
- + // load the pixels from a ByteBuffer, and then use `copyPixelsToBuffer` to read out. Note:
- + // for ALPHA_8 Bitmap, methods such as, `setPixels()` and `getPixels()` do not work.
- + Bitmap bitmap = Bitmap.createBitmap(
- + grayscale.getWidth(shape), grayscale.getHeight(shape), grayscale.toBitmapConfig());
- + uint8Buffer.getBuffer().rewind();
- + bitmap.copyPixelsFromBuffer(uint8Buffer.getBuffer());
- + return bitmap;
- + }
- +
- + /**
- + * Converts an Image in a Bitmap to a TensorBuffer (3D Tensor: Width-Height-Channel) whose
- + * memory is already allocated, or could be dynamically allocated.
- + *
- + * @param bitmap The Bitmap object representing the image. Currently we only support ARGB_8888
- + * config.
- + * @param buffer The destination of the conversion. Needs to be created in advance. If it's
- + * fixed-size, its flat size should be w*h*3.
- + * @throws IllegalArgumentException if the buffer is fixed-size, but the size doesn't match.
- + */
- + static void convertBitmapToTensorBuffer(Bitmap bitmap, TensorBuffer buffer) {
- + int w = bitmap.getWidth();
- + int h = bitmap.getHeight();
- + int[] intValues = new int[w * h];
- + bitmap.getPixels(intValues, 0, w, 0, 0, w, h);
- + // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
- + int[] shape = new int[] {h, w, 3};
- + switch (buffer.getDataType()) {
- + case UINT8:
- + byte[] byteArr = new byte[w * h * 3];
- + for (int i = 0, j = 0; i < intValues.length; i++) {
- + byteArr[j++] = (byte) ((intValues[i] >> 16) & 0xff);
- + byteArr[j++] = (byte) ((intValues[i] >> 8) & 0xff);
- + byteArr[j++] = (byte) (intValues[i] & 0xff);
- + }
- + ByteBuffer byteBuffer = ByteBuffer.wrap(byteArr);
- + byteBuffer.order(ByteOrder.nativeOrder());
- + buffer.loadBuffer(byteBuffer, shape);
- + break;
- + case FLOAT32:
- + float[] floatArr = new float[w * h * 3];
- + for (int i = 0, j = 0; i < intValues.length; i++) {
- + floatArr[j++] = (float) ((intValues[i] >> 16) & 0xff);
- + floatArr[j++] = (float) ((intValues[i] >> 8) & 0xff);
- + floatArr[j++] = (float) (intValues[i] & 0xff);
- + }
- + buffer.loadArray(floatArr, shape);
- + break;
- + default:
- + // Should never happen.
- + throw new IllegalStateException(
- + "The type of TensorBuffer, " + buffer.getBuffer() + ", is unsupported.");
- }
- - buffer.loadArray(floatArr, shape);
- - break;
- - default:
- - // Should never happen.
- - throw new IllegalStateException(
- - "The type of TensorBuffer, " + buffer.getBuffer() + ", is unsupported.");
- }
- - }
-
- - // Hide the constructor as the class is static.
- - private ImageConversions() {}
- + // Hide the constructor as the class is static.
- + private ImageConversions() {}
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java
- index 1e546634e90e7..e852569490f0b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java
- @@ -16,28 +16,29 @@ limitations under the License.
- package org.tensorflow.lite.support.image;
-
- import android.graphics.PointF;
- +
- import org.tensorflow.lite.support.common.Operator;
-
- /** Operates a TensorImage object. Used in ImageProcessor. */
- public interface ImageOperator extends Operator<TensorImage> {
- - /** @see org.tensorflow.lite.support.common.Operator#apply(java.lang.Object) */
- - @Override
- - TensorImage apply(TensorImage image);
- -
- - /** Computes the width of the expected output image when input image size is given. */
- - int getOutputImageWidth(int inputImageHeight, int inputImageWidth);
- -
- - /** Computes the height of the expected output image when input image size is given. */
- - int getOutputImageHeight(int inputImageHeight, int inputImageWidth);
- -
- - /**
- - * Transforms a point from coordinates system of the result image back to the one of the input
- - * image.
- - *
- - * @param point the point from the result coordinates system.
- - * @param inputImageHeight the height of input image.
- - * @param inputImageWidth the width of input image.
- - * @return the point with the coordinates from the coordinates system of the input image.
- - */
- - PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth);
- + /** @see org.tensorflow.lite.support.common.Operator#apply(java.lang.Object) */
- + @Override
- + TensorImage apply(TensorImage image);
- +
- + /** Computes the width of the expected output image when input image size is given. */
- + int getOutputImageWidth(int inputImageHeight, int inputImageWidth);
- +
- + /** Computes the height of the expected output image when input image size is given. */
- + int getOutputImageHeight(int inputImageHeight, int inputImageWidth);
- +
- + /**
- + * Transforms a point from coordinates system of the result image back to the one of the input
- + * image.
- + *
- + * @param point the point from the result coordinates system.
- + * @param inputImageHeight the height of input image.
- + * @param inputImageWidth the width of input image.
- + * @return the point with the coordinates from the coordinates system of the input image.
- + */
- + PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java
- index c44aa9efad708..c7d51355920ee 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java
- @@ -20,9 +20,7 @@ import static java.lang.Math.min;
-
- import android.graphics.PointF;
- import android.graphics.RectF;
- -import java.util.ArrayList;
- -import java.util.List;
- -import java.util.ListIterator;
- +
- import org.tensorflow.lite.support.common.Operator;
- import org.tensorflow.lite.support.common.SequentialProcessor;
- import org.tensorflow.lite.support.common.TensorOperator;
- @@ -30,6 +28,10 @@ import org.tensorflow.lite.support.common.internal.SupportPreconditions;
- import org.tensorflow.lite.support.image.ops.Rot90Op;
- import org.tensorflow.lite.support.image.ops.TensorOperatorWrapper;
-
- +import java.util.ArrayList;
- +import java.util.List;
- +import java.util.ListIterator;
- +
- /**
- * ImageProcessor is a helper class for preprocessing and postprocessing {@link TensorImage}. It
- * could transform a {@link TensorImage} to another by executing a chain of {@link ImageOperator}.
- @@ -55,156 +57,159 @@ import org.tensorflow.lite.support.image.ops.TensorOperatorWrapper;
- * @see ImageProcessor#process(TensorImage) to apply the processor on a {@code TensorImage}
- */
- public class ImageProcessor extends SequentialProcessor<TensorImage> {
- - private ImageProcessor(Builder builder) {
- - super(builder);
- - }
- -
- - /**
- - * Transforms a point from coordinates system of the result image back to the one of the input
- - * image.
- - *
- - * @param point the point from the result coordinates system.
- - * @param inputImageHeight the height of input image.
- - * @param inputImageWidth the width of input image.
- - * @return the point with the coordinates from the coordinates system of the input image.
- - */
- - public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- - List<Integer> widths = new ArrayList<>();
- - List<Integer> heights = new ArrayList<>();
- - int currentWidth = inputImageWidth;
- - int currentHeight = inputImageHeight;
- - for (Operator<TensorImage> op : operatorList) {
- - widths.add(currentWidth);
- - heights.add(currentHeight);
- - ImageOperator imageOperator = (ImageOperator) op;
- - int newHeight = imageOperator.getOutputImageHeight(currentHeight, currentWidth);
- - int newWidth = imageOperator.getOutputImageWidth(currentHeight, currentWidth);
- - currentHeight = newHeight;
- - currentWidth = newWidth;
- + private ImageProcessor(Builder builder) {
- + super(builder);
- }
- - ListIterator<Operator<TensorImage>> opIterator = operatorList.listIterator(operatorList.size());
- - ListIterator<Integer> widthIterator = widths.listIterator(widths.size());
- - ListIterator<Integer> heightIterator = heights.listIterator(heights.size());
- - while (opIterator.hasPrevious()) {
- - ImageOperator imageOperator = (ImageOperator) opIterator.previous();
- - int height = heightIterator.previous();
- - int width = widthIterator.previous();
- - point = imageOperator.inverseTransform(point, height, width);
- +
- + /**
- + * Transforms a point from coordinates system of the result image back to the one of the input
- + * image.
- + *
- + * @param point the point from the result coordinates system.
- + * @param inputImageHeight the height of input image.
- + * @param inputImageWidth the width of input image.
- + * @return the point with the coordinates from the coordinates system of the input image.
- + */
- + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- + List<Integer> widths = new ArrayList<>();
- + List<Integer> heights = new ArrayList<>();
- + int currentWidth = inputImageWidth;
- + int currentHeight = inputImageHeight;
- + for (Operator<TensorImage> op : operatorList) {
- + widths.add(currentWidth);
- + heights.add(currentHeight);
- + ImageOperator imageOperator = (ImageOperator) op;
- + int newHeight = imageOperator.getOutputImageHeight(currentHeight, currentWidth);
- + int newWidth = imageOperator.getOutputImageWidth(currentHeight, currentWidth);
- + currentHeight = newHeight;
- + currentWidth = newWidth;
- + }
- + ListIterator<Operator<TensorImage>> opIterator =
- + operatorList.listIterator(operatorList.size());
- + ListIterator<Integer> widthIterator = widths.listIterator(widths.size());
- + ListIterator<Integer> heightIterator = heights.listIterator(heights.size());
- + while (opIterator.hasPrevious()) {
- + ImageOperator imageOperator = (ImageOperator) opIterator.previous();
- + int height = heightIterator.previous();
- + int width = widthIterator.previous();
- + point = imageOperator.inverseTransform(point, height, width);
- + }
- + return point;
- + }
- +
- + /**
- + * Transforms a rectangle from coordinates system of the result image back to the one of the
- + * input image.
- + *
- + * @param rect the rectangle from the result coordinates system.
- + * @param inputImageHeight the height of input image.
- + * @param inputImageWidth the width of input image.
- + * @return the rectangle with the coordinates from the coordinates system of the input image.
- + */
- + public RectF inverseTransform(RectF rect, int inputImageHeight, int inputImageWidth) {
- + // when rotation is involved, corner order may change - top left changes to bottom right,
- + // .etc
- + PointF p1 = inverseTransform(
- + new PointF(rect.left, rect.top), inputImageHeight, inputImageWidth);
- + PointF p2 = inverseTransform(
- + new PointF(rect.right, rect.bottom), inputImageHeight, inputImageWidth);
- + return new RectF(min(p1.x, p2.x), min(p1.y, p2.y), max(p1.x, p2.x), max(p1.y, p2.y));
- }
- - return point;
- - }
- -
- - /**
- - * Transforms a rectangle from coordinates system of the result image back to the one of the input
- - * image.
- - *
- - * @param rect the rectangle from the result coordinates system.
- - * @param inputImageHeight the height of input image.
- - * @param inputImageWidth the width of input image.
- - * @return the rectangle with the coordinates from the coordinates system of the input image.
- - */
- - public RectF inverseTransform(RectF rect, int inputImageHeight, int inputImageWidth) {
- - // when rotation is involved, corner order may change - top left changes to bottom right, .etc
- - PointF p1 =
- - inverseTransform(new PointF(rect.left, rect.top), inputImageHeight, inputImageWidth);
- - PointF p2 =
- - inverseTransform(new PointF(rect.right, rect.bottom), inputImageHeight, inputImageWidth);
- - return new RectF(min(p1.x, p2.x), min(p1.y, p2.y), max(p1.x, p2.x), max(p1.y, p2.y));
- - }
- -
- - /**
- - * Processes a {@link TensorImage} object with prepared {@link TensorOperator}.
- - *
- - * @throws IllegalArgumentException if the image is not supported by any op.
- - */
- - @Override
- - public TensorImage process(TensorImage image) {
- - return super.process(image);
- - }
- -
- - /**
- - * The Builder to create an ImageProcessor, which could be executed later.
- - *
- - * @see #add(TensorOperator) to add a general TensorOperator
- - * @see #add(ImageOperator) to add an ImageOperator
- - * @see #build() complete the building process and get a built Processor
- - */
- - public static class Builder extends SequentialProcessor.Builder<TensorImage> {
- - public Builder() {
- - super();
- +
- + /**
- + * Processes a {@link TensorImage} object with prepared {@link TensorOperator}.
- + *
- + * @throws IllegalArgumentException if the image is not supported by any op.
- + */
- + @Override
- + public TensorImage process(TensorImage image) {
- + return super.process(image);
- }
-
- /**
- - * Adds an {@link ImageOperator} into the Operator chain.
- + * The Builder to create an ImageProcessor, which could be executed later.
- *
- - * @param op the Operator instance to be executed then
- + * @see #add(TensorOperator) to add a general TensorOperator
- + * @see #add(ImageOperator) to add an ImageOperator
- + * @see #build() complete the building process and get a built Processor
- */
- - public Builder add(ImageOperator op) {
- - super.add(op);
- - return this;
- + public static class Builder extends SequentialProcessor.Builder<TensorImage> {
- + public Builder() {
- + super();
- + }
- +
- + /**
- + * Adds an {@link ImageOperator} into the Operator chain.
- + *
- + * @param op the Operator instance to be executed then
- + */
- + public Builder add(ImageOperator op) {
- + super.add(op);
- + return this;
- + }
- +
- + /**
- + * Adds a {@link TensorOperator} into the Operator chain. In execution, the processor calls
- + * {@link TensorImage#getTensorBuffer()} to transform the {@link TensorImage} by
- + * transforming the underlying {@link
- + * org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
- + *
- + * @param op the Operator instance to be executed then
- + */
- + public Builder add(TensorOperator op) {
- + return add(new TensorOperatorWrapper(op));
- + }
- +
- + /** Completes the building process and gets the {@link ImageProcessor} instance. */
- + @Override
- + public ImageProcessor build() {
- + return new ImageProcessor(this);
- + }
- }
-
- /**
- - * Adds a {@link TensorOperator} into the Operator chain. In execution, the processor calls
- - * {@link TensorImage#getTensorBuffer()} to transform the {@link TensorImage} by transforming
- - * the underlying {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
- + * Updates the number of rotations for the first {@link Rot90Op} in this {@link ImageProcessor}.
- + *
- + * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and
- + * then processing images (using {@link #process}) must be protected from concurrent access with
- + * additional synchronization.
- *
- - * @param op the Operator instance to be executed then
- + * @param k the number of rotations
- + * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
- + * ImageProcessor}
- */
- - public Builder add(TensorOperator op) {
- - return add(new TensorOperatorWrapper(op));
- + public void updateNumberOfRotations(int k) {
- + updateNumberOfRotations(k, /*occurrence=*/0);
- }
-
- - /** Completes the building process and gets the {@link ImageProcessor} instance. */
- - @Override
- - public ImageProcessor build() {
- - return new ImageProcessor(this);
- + /**
- + * Updates the number of rotations for the {@link Rot90Op} specified by {@code occurrence} in
- + * this
- + * {@link ImageProcessor}.
- + *
- + * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and
- + * then processing images (using {@link #process}) must be protected from concurrent access with
- + * additional synchronization.
- + *
- + * @param k the number of rotations
- + * @param occurrence the index of perticular {@link Rot90Op} in this {@link ImageProcessor}. For
- + * example, if the second {@link Rot90Op} needs to be updated, {@code occurrence} should be
- + * set to 1.
- + * @throws IndexOutOfBoundsException if {@code occurrence} is negative or is not less than the
- + * number of {@link Rot90Op} in this {@link ImageProcessor}
- + * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
- + * ImageProcessor}
- + */
- + public synchronized void updateNumberOfRotations(int k, int occurrence) {
- + SupportPreconditions.checkState(operatorIndex.containsKey(Rot90Op.class.getName()),
- + "The Rot90Op has not been added to the ImageProcessor.");
- +
- + List<Integer> indexes = operatorIndex.get(Rot90Op.class.getName());
- + SupportPreconditions.checkElementIndex(occurrence, indexes.size(), "occurrence");
- +
- + // The index of the Rot90Op to be replaced in operatorList.
- + int index = indexes.get(occurrence);
- + Rot90Op newRot = new Rot90Op(k);
- + operatorList.set(index, newRot);
- }
- - }
- -
- - /**
- - * Updates the number of rotations for the first {@link Rot90Op} in this {@link ImageProcessor}.
- - *
- - * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and
- - * then processing images (using {@link #process}) must be protected from concurrent access with
- - * additional synchronization.
- - *
- - * @param k the number of rotations
- - * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
- - * ImageProcessor}
- - */
- - public void updateNumberOfRotations(int k) {
- - updateNumberOfRotations(k, /*occurrence=*/ 0);
- - }
- -
- - /**
- - * Updates the number of rotations for the {@link Rot90Op} specified by {@code occurrence} in this
- - * {@link ImageProcessor}.
- - *
- - * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and
- - * then processing images (using {@link #process}) must be protected from concurrent access with
- - * additional synchronization.
- - *
- - * @param k the number of rotations
- - * @param occurrence the index of perticular {@link Rot90Op} in this {@link ImageProcessor}. For
- - * example, if the second {@link Rot90Op} needs to be updated, {@code occurrence} should be
- - * set to 1.
- - * @throws IndexOutOfBoundsException if {@code occurrence} is negative or is not less than the
- - * number of {@link Rot90Op} in this {@link ImageProcessor}
- - * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
- - * ImageProcessor}
- - */
- - public synchronized void updateNumberOfRotations(int k, int occurrence) {
- - SupportPreconditions.checkState(
- - operatorIndex.containsKey(Rot90Op.class.getName()),
- - "The Rot90Op has not been added to the ImageProcessor.");
- -
- - List<Integer> indexes = operatorIndex.get(Rot90Op.class.getName());
- - SupportPreconditions.checkElementIndex(occurrence, indexes.size(), "occurrence");
- -
- - // The index of the Rot90Op to be replaced in operatorList.
- - int index = indexes.get(occurrence);
- - Rot90Op newRot = new Rot90Op(k);
- - operatorList.set(index, newRot);
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java
- index 96daf85a02f5a..f61f59fa13ce7 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java
- @@ -26,52 +26,51 @@ import com.google.auto.value.AutoValue;
- */
- @AutoValue
- public abstract class ImageProperties {
- + private static final int DEFAULT_HEIGHT = -1;
- + private static final int DEFAULT_WIDTH = -1;
-
- - private static final int DEFAULT_HEIGHT = -1;
- - private static final int DEFAULT_WIDTH = -1;
- -
- - public abstract int getHeight();
- -
- - public abstract int getWidth();
- -
- - public abstract ColorSpaceType getColorSpaceType();
- -
- - public static Builder builder() {
- - return new AutoValue_ImageProperties.Builder()
- - .setHeight(DEFAULT_HEIGHT)
- - .setWidth(DEFAULT_WIDTH);
- - }
- -
- - /**
- - * Builder for {@link ImageProperties}. Different image objects may require different properties.
- - * See the detais below:
- - *
- - * <ul>
- - * {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}:
- - * <li>Mandatory proterties: height / width / colorSpaceType. The shape of the TensorBuffer
- - * object will not be used to determine image height and width.
- - * </ul>
- - */
- - @AutoValue.Builder
- - public abstract static class Builder {
- - public abstract Builder setHeight(int height);
- -
- - public abstract Builder setWidth(int width);
- -
- - public abstract Builder setColorSpaceType(ColorSpaceType colorSpaceType);
- -
- - abstract ImageProperties autoBuild();
- -
- - public ImageProperties build() {
- - ImageProperties properties = autoBuild();
- - // If width or hight are not configured by the Builder, they will be -1.
- - // Enforcing all properties to be populated (AutoValue will error out if objects, like
- - // colorSpaceType, are not set up), since they are required for TensorBuffer images.
- - // If in the future we have some image object types that only require a portion of these
- - // properties, we can delay the check when TensorImage#load() is executed.
- - checkState(properties.getHeight() >= 0, "Negative image height is not allowed.");
- - checkState(properties.getWidth() >= 0, "Negative image width is not allowed.");
- - return properties;
- + public abstract int getHeight();
- +
- + public abstract int getWidth();
- +
- + public abstract ColorSpaceType getColorSpaceType();
- +
- + public static Builder builder() {
- + return new AutoValue_ImageProperties.Builder()
- + .setHeight(DEFAULT_HEIGHT)
- + .setWidth(DEFAULT_WIDTH);
- + }
- +
- + /**
- + * Builder for {@link ImageProperties}. Different image objects may require different
- + * properties. See the detais below:
- + *
- + * <ul>
- + * {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}:
- + * <li>Mandatory proterties: height / width / colorSpaceType. The shape of the TensorBuffer
- + * object will not be used to determine image height and width.
- + * </ul>
- + */
- + @AutoValue.Builder
- + public abstract static class Builder {
- + public abstract Builder setHeight(int height);
- +
- + public abstract Builder setWidth(int width);
- +
- + public abstract Builder setColorSpaceType(ColorSpaceType colorSpaceType);
- +
- + abstract ImageProperties autoBuild();
- +
- + public ImageProperties build() {
- + ImageProperties properties = autoBuild();
- + // If width or hight are not configured by the Builder, they will be -1.
- + // Enforcing all properties to be populated (AutoValue will error out if objects, like
- + // colorSpaceType, are not set up), since they are required for TensorBuffer images.
- + // If in the future we have some image object types that only require a portion of these
- + // properties, we can delay the check when TensorImage#load() is executed.
- + checkState(properties.getHeight() >= 0, "Negative image height is not allowed.");
- + checkState(properties.getWidth() >= 0, "Negative image width is not allowed.");
- + return properties;
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java
- index 50d787b5afab1..519aacaf7f20b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java
- @@ -21,65 +21,65 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
- import android.graphics.Bitmap;
- import android.graphics.ImageFormat;
- import android.media.Image;
- +
- import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- /** Holds an {@link Image} and converts it to other image formats as needed. */
- final class MediaImageContainer implements ImageContainer {
- -
- - private final Image image;
- -
- - /**
- - * Creates a {@link MediaImageContainer} object with a YUV_420_888 {@link Image}.
- - *
- - * @throws IllegalArgumentException if the {@link ImageFormat} of {@code image} is not ARGB_8888
- - */
- - static MediaImageContainer create(Image image) {
- - return new MediaImageContainer(image);
- - }
- -
- - private MediaImageContainer(Image image) {
- - checkNotNull(image, "Cannot load null Image.");
- - checkArgument(
- - image.getFormat() == ImageFormat.YUV_420_888, "Only supports loading YUV_420_888 Image.");
- - this.image = image;
- - }
- -
- - @Override
- - public MediaImageContainer clone() {
- - throw new UnsupportedOperationException(
- - "android.media.Image is an abstract class and cannot be cloned.");
- - }
- -
- - @Override
- - public Bitmap getBitmap() {
- - throw new UnsupportedOperationException(
- - "Converting an android.media.Image to Bitmap is not supported.");
- - }
- -
- - @Override
- - public TensorBuffer getTensorBuffer(DataType dataType) {
- - throw new UnsupportedOperationException(
- - "Converting an android.media.Image to TesorBuffer is not supported.");
- - }
- -
- - @Override
- - public Image getMediaImage() {
- - return image;
- - }
- -
- - @Override
- - public int getWidth() {
- - return image.getWidth();
- - }
- -
- - @Override
- - public int getHeight() {
- - return image.getHeight();
- - }
- -
- - @Override
- - public ColorSpaceType getColorSpaceType() {
- - return ColorSpaceType.fromImageFormat(image.getFormat());
- - }
- + private final Image image;
- +
- + /**
- + * Creates a {@link MediaImageContainer} object with a YUV_420_888 {@link Image}.
- + *
- + * @throws IllegalArgumentException if the {@link ImageFormat} of {@code image} is not ARGB_8888
- + */
- + static MediaImageContainer create(Image image) {
- + return new MediaImageContainer(image);
- + }
- +
- + private MediaImageContainer(Image image) {
- + checkNotNull(image, "Cannot load null Image.");
- + checkArgument(image.getFormat() == ImageFormat.YUV_420_888,
- + "Only supports loading YUV_420_888 Image.");
- + this.image = image;
- + }
- +
- + @Override
- + public MediaImageContainer clone() {
- + throw new UnsupportedOperationException(
- + "android.media.Image is an abstract class and cannot be cloned.");
- + }
- +
- + @Override
- + public Bitmap getBitmap() {
- + throw new UnsupportedOperationException(
- + "Converting an android.media.Image to Bitmap is not supported.");
- + }
- +
- + @Override
- + public TensorBuffer getTensorBuffer(DataType dataType) {
- + throw new UnsupportedOperationException(
- + "Converting an android.media.Image to TesorBuffer is not supported.");
- + }
- +
- + @Override
- + public Image getMediaImage() {
- + return image;
- + }
- +
- + @Override
- + public int getWidth() {
- + return image.getWidth();
- + }
- +
- + @Override
- + public int getHeight() {
- + return image.getHeight();
- + }
- +
- + @Override
- + public ColorSpaceType getColorSpaceType() {
- + return ColorSpaceType.fromImageFormat(image.getFormat());
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java
- index ed066e5308fb9..03017bf733f02 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java
- @@ -21,91 +21,99 @@ import com.google.android.odml.image.MediaImageExtractor;
- import com.google.android.odml.image.MlImage;
- import com.google.android.odml.image.MlImage.ImageFormat;
- import com.google.auto.value.AutoValue;
- +
- import java.nio.ByteBuffer;
-
- /** Converts {@code MlImage} to {@link TensorImage} and vice versa. */
- public class MlImageAdapter {
- + /** Proxies an {@link ImageFormat} and its equivalent {@link ColorSpaceType}. */
- + @AutoValue
- + abstract static class ImageFormatProxy {
- + abstract ColorSpaceType getColorSpaceType();
-
- - /** Proxies an {@link ImageFormat} and its equivalent {@link ColorSpaceType}. */
- - @AutoValue
- - abstract static class ImageFormatProxy {
- -
- - abstract ColorSpaceType getColorSpaceType();
- + @ImageFormat
- + abstract int getImageFormat();
-
- - @ImageFormat
- - abstract int getImageFormat();
- -
- - static ImageFormatProxy createFromImageFormat(@ImageFormat int format) {
- - switch (format) {
- - case MlImage.IMAGE_FORMAT_RGB:
- - return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.RGB, format);
- - case MlImage.IMAGE_FORMAT_NV12:
- - return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.NV12, format);
- - case MlImage.IMAGE_FORMAT_NV21:
- - return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.NV21, format);
- - case MlImage.IMAGE_FORMAT_YV12:
- - return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.YV12, format);
- - case MlImage.IMAGE_FORMAT_YV21:
- - return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.YV21, format);
- - case MlImage.IMAGE_FORMAT_YUV_420_888:
- - return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.YUV_420_888, format);
- - case MlImage.IMAGE_FORMAT_ALPHA:
- - return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.GRAYSCALE, format);
- - case MlImage.IMAGE_FORMAT_RGBA:
- - case MlImage.IMAGE_FORMAT_JPEG:
- - case MlImage.IMAGE_FORMAT_UNKNOWN:
- - throw new IllegalArgumentException(
- - "Cannot create ColorSpaceType from MlImage format: " + format);
- - default:
- - throw new AssertionError("Illegal @ImageFormat: " + format);
- - }
- + static ImageFormatProxy createFromImageFormat(@ImageFormat int format) {
- + switch (format) {
- + case MlImage.IMAGE_FORMAT_RGB:
- + return new AutoValue_MlImageAdapter_ImageFormatProxy(
- + ColorSpaceType.RGB, format);
- + case MlImage.IMAGE_FORMAT_NV12:
- + return new AutoValue_MlImageAdapter_ImageFormatProxy(
- + ColorSpaceType.NV12, format);
- + case MlImage.IMAGE_FORMAT_NV21:
- + return new AutoValue_MlImageAdapter_ImageFormatProxy(
- + ColorSpaceType.NV21, format);
- + case MlImage.IMAGE_FORMAT_YV12:
- + return new AutoValue_MlImageAdapter_ImageFormatProxy(
- + ColorSpaceType.YV12, format);
- + case MlImage.IMAGE_FORMAT_YV21:
- + return new AutoValue_MlImageAdapter_ImageFormatProxy(
- + ColorSpaceType.YV21, format);
- + case MlImage.IMAGE_FORMAT_YUV_420_888:
- + return new AutoValue_MlImageAdapter_ImageFormatProxy(
- + ColorSpaceType.YUV_420_888, format);
- + case MlImage.IMAGE_FORMAT_ALPHA:
- + return new AutoValue_MlImageAdapter_ImageFormatProxy(
- + ColorSpaceType.GRAYSCALE, format);
- + case MlImage.IMAGE_FORMAT_RGBA:
- + case MlImage.IMAGE_FORMAT_JPEG:
- + case MlImage.IMAGE_FORMAT_UNKNOWN:
- + throw new IllegalArgumentException(
- + "Cannot create ColorSpaceType from MlImage format: " + format);
- + default:
- + throw new AssertionError("Illegal @ImageFormat: " + format);
- + }
- + }
- }
- - }
-
- - /**
- - * Creates a {@link TensorImage} from an {@link MlImage}.
- - *
- - * <p>IMPORTANT: The returned {@link TensorImage} shares storage with {@code mlImage}, so do not
- - * modify the contained object in the {@link TensorImage}, as {@code MlImage} expects its
- - * contained data are immutable. Also, callers should use {@code MlImage#getInternal()#acquire()}
- - * and {@code MlImage#release()} to avoid the {@code mlImage} being released unexpectedly.
- - *
- - * @throws IllegalArgumentException if the {@code mlImage} is built from an unsupported container.
- - */
- - public static TensorImage createTensorImageFrom(MlImage mlImage) {
- - // TODO(b/190670174): Choose the best storage from multiple containers.
- - com.google.android.odml.image.ImageProperties mlImageProperties =
- - mlImage.getContainedImageProperties().get(0);
- - switch (mlImageProperties.getStorageType()) {
- - case MlImage.STORAGE_TYPE_BITMAP:
- - return TensorImage.fromBitmap(BitmapExtractor.extract(mlImage));
- - case MlImage.STORAGE_TYPE_MEDIA_IMAGE:
- - TensorImage mediaTensorImage = new TensorImage();
- - mediaTensorImage.load(MediaImageExtractor.extract(mlImage));
- - return mediaTensorImage;
- - case MlImage.STORAGE_TYPE_BYTEBUFFER:
- - ByteBuffer buffer = ByteBufferExtractor.extract(mlImage);
- - ImageFormatProxy formatProxy =
- - ImageFormatProxy.createFromImageFormat(mlImageProperties.getImageFormat());
- - TensorImage byteBufferTensorImage = new TensorImage();
- - ImageProperties properties =
- - ImageProperties.builder()
- - .setColorSpaceType(formatProxy.getColorSpaceType())
- - .setHeight(mlImage.getHeight())
- - .setWidth(mlImage.getWidth())
- - .build();
- - byteBufferTensorImage.load(buffer, properties);
- - return byteBufferTensorImage;
- - default:
- - throw new IllegalArgumentException(
- - "Illegal storage type: " + mlImageProperties.getStorageType());
- + /**
- + * Creates a {@link TensorImage} from an {@link MlImage}.
- + *
- + * <p>IMPORTANT: The returned {@link TensorImage} shares storage with {@code mlImage}, so do not
- + * modify the contained object in the {@link TensorImage}, as {@code MlImage} expects its
- + * contained data are immutable. Also, callers should use {@code
- + * MlImage#getInternal()#acquire()} and {@code MlImage#release()} to avoid the {@code mlImage}
- + * being released unexpectedly.
- + *
- + * @throws IllegalArgumentException if the {@code mlImage} is built from an unsupported
- + * container.
- + */
- + public static TensorImage createTensorImageFrom(MlImage mlImage) {
- + // TODO(b/190670174): Choose the best storage from multiple containers.
- + com.google.android.odml.image.ImageProperties mlImageProperties =
- + mlImage.getContainedImageProperties().get(0);
- + switch (mlImageProperties.getStorageType()) {
- + case MlImage.STORAGE_TYPE_BITMAP:
- + return TensorImage.fromBitmap(BitmapExtractor.extract(mlImage));
- + case MlImage.STORAGE_TYPE_MEDIA_IMAGE:
- + TensorImage mediaTensorImage = new TensorImage();
- + mediaTensorImage.load(MediaImageExtractor.extract(mlImage));
- + return mediaTensorImage;
- + case MlImage.STORAGE_TYPE_BYTEBUFFER:
- + ByteBuffer buffer = ByteBufferExtractor.extract(mlImage);
- + ImageFormatProxy formatProxy =
- + ImageFormatProxy.createFromImageFormat(mlImageProperties.getImageFormat());
- + TensorImage byteBufferTensorImage = new TensorImage();
- + ImageProperties properties =
- + ImageProperties.builder()
- + .setColorSpaceType(formatProxy.getColorSpaceType())
- + .setHeight(mlImage.getHeight())
- + .setWidth(mlImage.getWidth())
- + .build();
- + byteBufferTensorImage.load(buffer, properties);
- + return byteBufferTensorImage;
- + default:
- + throw new IllegalArgumentException(
- + "Illegal storage type: " + mlImageProperties.getStorageType());
- + }
- }
- - }
-
- - /** Creatas a {@link ColorSpaceType} from {@code MlImage.ImageFormat}. */
- - public static ColorSpaceType createColorSpaceTypeFrom(@ImageFormat int imageFormat) {
- - return ImageFormatProxy.createFromImageFormat(imageFormat).getColorSpaceType();
- - }
- + /** Creatas a {@link ColorSpaceType} from {@code MlImage.ImageFormat}. */
- + public static ColorSpaceType createColorSpaceTypeFrom(@ImageFormat int imageFormat) {
- + return ImageFormatProxy.createFromImageFormat(imageFormat).getColorSpaceType();
- + }
-
- - private MlImageAdapter() {}
- + private MlImageAdapter() {}
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java
- index 39e2ceb9db521..6dfef70ba67f7 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java
- @@ -20,118 +20,108 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
- import android.graphics.Bitmap;
- import android.media.Image;
- import android.util.Log;
- +
- import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- /** Holds a {@link TensorBuffer} and converts it to other image formats as needed. */
- final class TensorBufferContainer implements ImageContainer {
- + private final TensorBuffer buffer;
- + private final ColorSpaceType colorSpaceType;
- + private final int height;
- + private final int width;
- + private static final String TAG = TensorBufferContainer.class.getSimpleName();
- +
- + /**
- + * Creates a {@link TensorBufferContainer} object with the specified {@link
- + * TensorImage#ColorSpaceType}.
- + *
- + * <p>Only supports {@link ColorSapceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link
- + * #create(TensorBuffer, ImageProperties)} for other color space types.
- + *
- + * @throws IllegalArgumentException if the shape of the {@link TensorBuffer} does not match the
- + * specified color space type, or if the color space type is not supported
- + */
- + static TensorBufferContainer create(TensorBuffer buffer, ColorSpaceType colorSpaceType) {
- + checkArgument(
- + colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE,
- + "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
- + + " `create(TensorBuffer, ImageProperties)` for other color space types.");
- +
- + return new TensorBufferContainer(buffer, colorSpaceType,
- + colorSpaceType.getHeight(buffer.getShape()),
- + colorSpaceType.getWidth(buffer.getShape()));
- + }
-
- - private final TensorBuffer buffer;
- - private final ColorSpaceType colorSpaceType;
- - private final int height;
- - private final int width;
- - private static final String TAG = TensorBufferContainer.class.getSimpleName();
- -
- - /**
- - * Creates a {@link TensorBufferContainer} object with the specified {@link
- - * TensorImage#ColorSpaceType}.
- - *
- - * <p>Only supports {@link ColorSapceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link
- - * #create(TensorBuffer, ImageProperties)} for other color space types.
- - *
- - * @throws IllegalArgumentException if the shape of the {@link TensorBuffer} does not match the
- - * specified color space type, or if the color space type is not supported
- - */
- - static TensorBufferContainer create(TensorBuffer buffer, ColorSpaceType colorSpaceType) {
- - checkArgument(
- - colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE,
- - "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
- - + " `create(TensorBuffer, ImageProperties)` for other color space types.");
- -
- - return new TensorBufferContainer(
- - buffer,
- - colorSpaceType,
- - colorSpaceType.getHeight(buffer.getShape()),
- - colorSpaceType.getWidth(buffer.getShape()));
- - }
- -
- - static TensorBufferContainer create(TensorBuffer buffer, ImageProperties imageProperties) {
- - return new TensorBufferContainer(
- - buffer,
- - imageProperties.getColorSpaceType(),
- - imageProperties.getHeight(),
- - imageProperties.getWidth());
- - }
- -
- - private TensorBufferContainer(
- - TensorBuffer buffer, ColorSpaceType colorSpaceType, int height, int width) {
- - checkArgument(
- - colorSpaceType != ColorSpaceType.YUV_420_888,
- - "The actual encoding format of YUV420 is required. Choose a ColorSpaceType from: NV12,"
- - + " NV21, YV12, YV21. Use YUV_420_888 only when loading an android.media.Image.");
- -
- - colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
- - this.buffer = buffer;
- - this.colorSpaceType = colorSpaceType;
- - this.height = height;
- - this.width = width;
- - }
- -
- - @Override
- - public TensorBufferContainer clone() {
- - return new TensorBufferContainer(
- - TensorBuffer.createFrom(buffer, buffer.getDataType()),
- - colorSpaceType,
- - getHeight(),
- - getWidth());
- - }
- -
- - @Override
- - public Bitmap getBitmap() {
- - if (buffer.getDataType() != DataType.UINT8) {
- - // Print warning instead of throwing an exception. When using float models, users may want to
- - // convert the resulting float image into Bitmap. That's fine to do so, as long as they are
- - // aware of the potential accuracy lost when casting to uint8.
- - Log.w(
- - TAG,
- - "<Warning> TensorBufferContainer is holding a non-uint8 image. The conversion to Bitmap"
- - + " will cause numeric casting and clamping on the data value.");
- + static TensorBufferContainer create(TensorBuffer buffer, ImageProperties imageProperties) {
- + return new TensorBufferContainer(buffer, imageProperties.getColorSpaceType(),
- + imageProperties.getHeight(), imageProperties.getWidth());
- }
-
- - return colorSpaceType.convertTensorBufferToBitmap(buffer);
- - }
- -
- - @Override
- - public TensorBuffer getTensorBuffer(DataType dataType) {
- - // If the data type of buffer is desired, return it directly. Not making a defensive copy for
- - // performance considerations. During image processing, users may need to set and get the
- - // TensorBuffer many times.
- - // Otherwise, create another one with the expected data type.
- - return buffer.getDataType() == dataType ? buffer : TensorBuffer.createFrom(buffer, dataType);
- - }
- -
- - @Override
- - public Image getMediaImage() {
- - throw new UnsupportedOperationException(
- - "Converting from TensorBuffer to android.media.Image is unsupported.");
- - }
- -
- - @Override
- - public int getWidth() {
- - // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created.
- - colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
- - return width;
- - }
- -
- - @Override
- - public int getHeight() {
- - // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created.
- - colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
- - return height;
- - }
- -
- - @Override
- - public ColorSpaceType getColorSpaceType() {
- - return colorSpaceType;
- - }
- + private TensorBufferContainer(
- + TensorBuffer buffer, ColorSpaceType colorSpaceType, int height, int width) {
- + checkArgument(colorSpaceType != ColorSpaceType.YUV_420_888,
- + "The actual encoding format of YUV420 is required. Choose a ColorSpaceType from: NV12,"
- + + " NV21, YV12, YV21. Use YUV_420_888 only when loading an android.media.Image.");
- +
- + colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
- + this.buffer = buffer;
- + this.colorSpaceType = colorSpaceType;
- + this.height = height;
- + this.width = width;
- + }
- +
- + @Override
- + public TensorBufferContainer clone() {
- + return new TensorBufferContainer(TensorBuffer.createFrom(buffer, buffer.getDataType()),
- + colorSpaceType, getHeight(), getWidth());
- + }
- +
- + @Override
- + public Bitmap getBitmap() {
- + if (buffer.getDataType() != DataType.UINT8) {
- + // Print warning instead of throwing an exception. When using float models, users may
- + // want to convert the resulting float image into Bitmap. That's fine to do so, as long
- + // as they are aware of the potential accuracy lost when casting to uint8.
- + Log.w(TAG,
- + "<Warning> TensorBufferContainer is holding a non-uint8 image. The conversion to Bitmap"
- + + " will cause numeric casting and clamping on the data value.");
- + }
- +
- + return colorSpaceType.convertTensorBufferToBitmap(buffer);
- + }
- +
- + @Override
- + public TensorBuffer getTensorBuffer(DataType dataType) {
- + // If the data type of buffer is desired, return it directly. Not making a defensive copy
- + // for performance considerations. During image processing, users may need to set and get
- + // the TensorBuffer many times. Otherwise, create another one with the expected data type.
- + return buffer.getDataType() == dataType ? buffer
- + : TensorBuffer.createFrom(buffer, dataType);
- + }
- +
- + @Override
- + public Image getMediaImage() {
- + throw new UnsupportedOperationException(
- + "Converting from TensorBuffer to android.media.Image is unsupported.");
- + }
- +
- + @Override
- + public int getWidth() {
- + // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created.
- + colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
- + return width;
- + }
- +
- + @Override
- + public int getHeight() {
- + // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created.
- + colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
- + return height;
- + }
- +
- + @Override
- + public ColorSpaceType getColorSpaceType() {
- + return colorSpaceType;
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
- index 1624971817aba..83cf4c0f648b2 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
- @@ -19,10 +19,12 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
-
- import android.graphics.Bitmap;
- import android.media.Image;
- -import java.nio.ByteBuffer;
- +
- import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- +import java.nio.ByteBuffer;
- +
- /**
- * TensorImage is the wrapper class for Image object. When using image processing utils in
- * TFLite.support library, it's common to convert image objects in variant types to TensorImage at
- @@ -49,350 +51,357 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- // TODO(b/138907116): Support loading images from TensorBuffer with properties.
- // TODO(b/138905544): Support directly loading RGBBytes, YUVBytes and other types if necessary.
- public class TensorImage {
- + private final DataType dataType;
- + private ImageContainer container = null;
- +
- + /**
- + * Initializes a {@link TensorImage} object.
- + *
- + * <p>Note: the data type of this {@link TensorImage} is {@link DataType#UINT8}. Use {@link
- + * #TensorImage(DataType)} if other data types are preferred.
- + */
- + public TensorImage() {
- + this(DataType.UINT8);
- + }
- +
- + /**
- + * Initializes a {@link TensorImage} object with the specified data type.
- + *
- + * <p>When getting a {@link TensorBuffer} or a {@link ByteBuffer} from this {@link TensorImage},
- + * such as using {@link #getTensorBuffer} and {@link #getBuffer}, the data values will be
- + * converted to the specified data type.
- + *
- + * <p>Note: the shape of a {@link TensorImage} is not fixed. It can be adjusted to the shape of
- + * the image being loaded to this {@link TensorImage}.
- + *
- + * @param dataType the expected data type of the resulting {@link TensorBuffer}. The type is
- + * always fixed during the lifetime of the {@link TensorImage}. To convert the data type,
- + * use
- + * {@link #createFrom(TensorImage, DataType)} to create a copy and convert data type at the
- + * same time.
- + * @throws IllegalArgumentException if {@code dataType} is neither {@link DataType#UINT8} nor
- + * {@link DataType#FLOAT32}
- + */
- + public TensorImage(DataType dataType) {
- + checkArgument(dataType == DataType.UINT8 || dataType == DataType.FLOAT32,
- + "Illegal data type for TensorImage: Only FLOAT32 and UINT8 are accepted");
- + this.dataType = dataType;
- + }
- +
- + /**
- + * Initializes a {@link TensorImage} object of {@link DataType#UINT8} with a {@link
- + * android.graphics.Bitmap} .
- + *
- + * @see #load(Bitmap) for reusing the object when it's expensive to create objects frequently,
- + * because every call of {@code fromBitmap} creates a new {@link TensorImage}.
- + */
- + public static TensorImage fromBitmap(Bitmap bitmap) {
- + TensorImage image = new TensorImage();
- + image.load(bitmap);
- + return image;
- + }
- +
- + /**
- + * Creates a deep-copy of a given {@link TensorImage} with the desired data type.
- + *
- + * @param src the {@link TensorImage} to copy from
- + * @param dataType the expected data type of newly created {@link TensorImage}
- + * @return a {@link TensorImage} whose data is copied from {@code src} and data type is {@code
- + * dataType}
- + */
- + public static TensorImage createFrom(TensorImage src, DataType dataType) {
- + TensorImage dst = new TensorImage(dataType);
- + dst.container = src.container.clone();
- + return dst;
- + }
- +
- + /**
- + * Loads a {@link android.graphics.Bitmap} image object into this {@link TensorImage}.
- + *
- + * <p>Note: if the {@link TensorImage} has data type other than {@link DataType#UINT8}, numeric
- + * casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- + * #getBuffer}, where the {@link android.graphics.Bitmap} will be converted into a {@link
- + * TensorBuffer}.
- + *
- + * <p>Important: when loading a bitmap, DO NOT MODIFY the bitmap from the caller side anymore.
- + * The
- + * {@link TensorImage} object will rely on the bitmap. It will probably modify the bitmap as
- + * well. In this method, we perform a zero-copy approach for that bitmap, by simply holding its
- + * reference. Use {@code bitmap.copy(bitmap.getConfig(), true)} to create a copy if necessary.
- + *
- + * <p>Note: to get the best performance, please load images in the same shape to avoid memory
- + * re-allocation.
- + *
- + * @throws IllegalArgumentException if {@code bitmap} is not in ARGB_8888
- + */
- + public void load(Bitmap bitmap) {
- + container = BitmapContainer.create(bitmap);
- + }
- +
- + /**
- + * Loads a float array as RGB pixels into this {@link TensorImage}, representing the pixels
- + * inside.
- + *
- + * <p>Note: if the {@link TensorImage} has a data type other than {@link DataType#FLOAT32},
- + * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- + * #getBuffer}.
- + *
- + * @param pixels the RGB pixels representing the image
- + * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3)
- + * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
- + */
- + public void load(float[] pixels, int[] shape) {
- + TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
- + buffer.loadArray(pixels, shape);
- + load(buffer);
- + }
-
- - private final DataType dataType;
- - private ImageContainer container = null;
- -
- - /**
- - * Initializes a {@link TensorImage} object.
- - *
- - * <p>Note: the data type of this {@link TensorImage} is {@link DataType#UINT8}. Use {@link
- - * #TensorImage(DataType)} if other data types are preferred.
- - */
- - public TensorImage() {
- - this(DataType.UINT8);
- - }
- -
- - /**
- - * Initializes a {@link TensorImage} object with the specified data type.
- - *
- - * <p>When getting a {@link TensorBuffer} or a {@link ByteBuffer} from this {@link TensorImage},
- - * such as using {@link #getTensorBuffer} and {@link #getBuffer}, the data values will be
- - * converted to the specified data type.
- - *
- - * <p>Note: the shape of a {@link TensorImage} is not fixed. It can be adjusted to the shape of
- - * the image being loaded to this {@link TensorImage}.
- - *
- - * @param dataType the expected data type of the resulting {@link TensorBuffer}. The type is
- - * always fixed during the lifetime of the {@link TensorImage}. To convert the data type, use
- - * {@link #createFrom(TensorImage, DataType)} to create a copy and convert data type at the
- - * same time.
- - * @throws IllegalArgumentException if {@code dataType} is neither {@link DataType#UINT8} nor
- - * {@link DataType#FLOAT32}
- - */
- - public TensorImage(DataType dataType) {
- - checkArgument(
- - dataType == DataType.UINT8 || dataType == DataType.FLOAT32,
- - "Illegal data type for TensorImage: Only FLOAT32 and UINT8 are accepted");
- - this.dataType = dataType;
- - }
- -
- - /**
- - * Initializes a {@link TensorImage} object of {@link DataType#UINT8} with a {@link
- - * android.graphics.Bitmap} .
- - *
- - * @see #load(Bitmap) for reusing the object when it's expensive to create objects frequently,
- - * because every call of {@code fromBitmap} creates a new {@link TensorImage}.
- - */
- - public static TensorImage fromBitmap(Bitmap bitmap) {
- - TensorImage image = new TensorImage();
- - image.load(bitmap);
- - return image;
- - }
- -
- - /**
- - * Creates a deep-copy of a given {@link TensorImage} with the desired data type.
- - *
- - * @param src the {@link TensorImage} to copy from
- - * @param dataType the expected data type of newly created {@link TensorImage}
- - * @return a {@link TensorImage} whose data is copied from {@code src} and data type is {@code
- - * dataType}
- - */
- - public static TensorImage createFrom(TensorImage src, DataType dataType) {
- - TensorImage dst = new TensorImage(dataType);
- - dst.container = src.container.clone();
- - return dst;
- - }
- -
- - /**
- - * Loads a {@link android.graphics.Bitmap} image object into this {@link TensorImage}.
- - *
- - * <p>Note: if the {@link TensorImage} has data type other than {@link DataType#UINT8}, numeric
- - * casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- - * #getBuffer}, where the {@link android.graphics.Bitmap} will be converted into a {@link
- - * TensorBuffer}.
- - *
- - * <p>Important: when loading a bitmap, DO NOT MODIFY the bitmap from the caller side anymore. The
- - * {@link TensorImage} object will rely on the bitmap. It will probably modify the bitmap as well.
- - * In this method, we perform a zero-copy approach for that bitmap, by simply holding its
- - * reference. Use {@code bitmap.copy(bitmap.getConfig(), true)} to create a copy if necessary.
- - *
- - * <p>Note: to get the best performance, please load images in the same shape to avoid memory
- - * re-allocation.
- - *
- - * @throws IllegalArgumentException if {@code bitmap} is not in ARGB_8888
- - */
- - public void load(Bitmap bitmap) {
- - container = BitmapContainer.create(bitmap);
- - }
- -
- - /**
- - * Loads a float array as RGB pixels into this {@link TensorImage}, representing the pixels
- - * inside.
- - *
- - * <p>Note: if the {@link TensorImage} has a data type other than {@link DataType#FLOAT32},
- - * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- - * #getBuffer}.
- - *
- - * @param pixels the RGB pixels representing the image
- - * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3)
- - * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
- - */
- - public void load(float[] pixels, int[] shape) {
- - TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
- - buffer.loadArray(pixels, shape);
- - load(buffer);
- - }
- -
- - /**
- - * Loads an int array as RGB pixels into this {@link TensorImage}, representing the pixels inside.
- - *
- - * <p>Note: numeric casting and clamping will be applied to convert the values into the data type
- - * of this {@link TensorImage} when calling {@link #getTensorBuffer} and {@link #getBuffer}.
- - *
- - * @param pixels the RGB pixels representing the image
- - * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3)
- - * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
- - */
- - public void load(int[] pixels, int[] shape) {
- - TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
- - buffer.loadArray(pixels, shape);
- - load(buffer);
- - }
- -
- - /**
- - * Loads a {@link TensorBuffer} containing pixel values. The color layout should be RGB.
- - *
- - * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
- - * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- - * #getBuffer}.
- - *
- - * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or
- - * (1, h, w, 3)
- - * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
- - */
- - public void load(TensorBuffer buffer) {
- - load(buffer, ColorSpaceType.RGB);
- - }
- -
- - /**
- - * Loads a {@link TensorBuffer} containing pixel values with the specific {@link ColorSpaceType}.
- - *
- - * <p>Only supports {@link ColorSpaceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link
- - * #load(TensorBuffer, ImageProperties)} for other color space types.
- - *
- - * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
- - * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- - * #getBuffer}.
- - *
- - * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or
- - * (1, h, w, 3) for RGB images, and either (h, w) or (1, h, w) for GRAYSCALE images
- - * @throws IllegalArgumentException if the shape of buffer does not match the color space type, or
- - * if the color space type is not supported
- - */
- - public void load(TensorBuffer buffer, ColorSpaceType colorSpaceType) {
- - checkArgument(
- - colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE,
- - "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
- - + " `load(TensorBuffer, ImageProperties)` for other color space types.");
- -
- - container = TensorBufferContainer.create(buffer, colorSpaceType);
- - }
- -
- - /**
- - * Loads a {@link TensorBuffer} containing pixel values with the specific {@link ImageProperties}.
- - *
- - * <p>The shape of the {@link TensorBuffer} will not be used to determine image height and width.
- - * Set image properties through {@link ImageProperties}.
- - *
- - * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
- - * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- - * #getBuffer}.
- - *
- - * @throws IllegalArgumentException if buffer size is less than the image size indicated by image
- - * height, width, and color space type in {@link ImageProperties}
- - */
- - public void load(TensorBuffer buffer, ImageProperties imageProperties) {
- - container = TensorBufferContainer.create(buffer, imageProperties);
- - }
- -
- - /**
- - * Loads a {@link ByteBuffer} containing pixel values with the specific {@link ImageProperties}.
- - *
- - * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
- - * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- - * #getBuffer}.
- - *
- - * @throws IllegalArgumentException if buffer size is less than the image size indicated by image
- - * height, width, and color space type in {@link ImageProperties}
- - */
- - public void load(ByteBuffer buffer, ImageProperties imageProperties) {
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- - tensorBuffer.loadBuffer(buffer, new int[] {buffer.limit()});
- - container = TensorBufferContainer.create(tensorBuffer, imageProperties);
- - }
- -
- - /**
- - * Loads an {@link android.media.Image} object into this {@link TensorImage}.
- - *
- - * <p>The main usage of this method is to load an {@link android.media.Image} object as model
- - * input to the <a href="TFLite Task
- - * Library">https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview</a>.
- - * {@link TensorImage} backed by {@link android.media.Image} is not supported by {@link
- - * ImageProcessor}.
- - *
- - * <p>* @throws IllegalArgumentException if the {@link android.graphics.ImageFormat} of {@code
- - * image} is not YUV_420_888
- - */
- - public void load(Image image) {
- - container = MediaImageContainer.create(image);
- - }
- -
- - /**
- - * Returns a {@link android.graphics.Bitmap} representation of this {@link TensorImage}.
- - *
- - * <p>Numeric casting and clamping will be applied if the stored data is not uint8.
- - *
- - * <p>Note that, the reliable way to get pixels from an {@code ALPHA_8} Bitmap is to use {@code
- - * copyPixelsToBuffer}. Bitmap methods such as, `setPixels()` and `getPixels` do not work.
- - *
- - * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
- - * concern, but if modification is necessary, please make a copy.
- - *
- - * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A"
- - * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} of
- - * this {@link TensorBuffer}.
- - * @throws IllegalStateException if the {@link TensorImage} never loads data
- - */
- - public Bitmap getBitmap() {
- - if (container == null) {
- - throw new IllegalStateException("No image has been loaded yet.");
- + /**
- + * Loads an int array as RGB pixels into this {@link TensorImage}, representing the pixels
- + * inside.
- + *
- + * <p>Note: numeric casting and clamping will be applied to convert the values into the data
- + * type of this {@link TensorImage} when calling {@link #getTensorBuffer} and {@link
- + * #getBuffer}.
- + *
- + * @param pixels the RGB pixels representing the image
- + * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3)
- + * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
- + */
- + public void load(int[] pixels, int[] shape) {
- + TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
- + buffer.loadArray(pixels, shape);
- + load(buffer);
- }
-
- - return container.getBitmap();
- - }
- -
- - /**
- - * Returns a {@link ByteBuffer} representation of this {@link TensorImage} with the expected data
- - * type.
- - *
- - * <p>Numeric casting and clamping will be applied if the stored data is different from the data
- - * type of the {@link TensorImage}.
- - *
- - * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
- - * concern, but if modification is necessary, please make a copy.
- - *
- - * <p>It's essentially a short cut for {@code getTensorBuffer().getBuffer()}.
- - *
- - * @return a reference to a {@link ByteBuffer} which holds the image data
- - * @throws IllegalStateException if the {@link TensorImage} never loads data
- - */
- - public ByteBuffer getBuffer() {
- - return getTensorBuffer().getBuffer();
- - }
- -
- - /**
- - * Returns a {@link TensorBuffer} representation of this {@link TensorImage} with the expected
- - * data type.
- - *
- - * <p>Numeric casting and clamping will be applied if the stored data is different from the data
- - * type of the {@link TensorImage}.
- - *
- - * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
- - * concern, but if modification is necessary, please make a copy.
- - *
- - * @return a reference to a {@link TensorBuffer} which holds the image data
- - * @throws IllegalStateException if the {@link TensorImage} never loads data
- - */
- - public TensorBuffer getTensorBuffer() {
- - if (container == null) {
- - throw new IllegalStateException("No image has been loaded yet.");
- + /**
- + * Loads a {@link TensorBuffer} containing pixel values. The color layout should be RGB.
- + *
- + * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
- + * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- + * #getBuffer}.
- + *
- + * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or
- + * (1, h, w, 3)
- + * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
- + */
- + public void load(TensorBuffer buffer) {
- + load(buffer, ColorSpaceType.RGB);
- }
-
- - return container.getTensorBuffer(dataType);
- - }
- -
- - /**
- - * Returns an {@link android.media.Image} representation of this {@link TensorImage}.
- - *
- - * <p>This method only works when the {@link TensorImage} is backed by an {@link
- - * android.media.Image}, meaning you need to first load an {@link android.media.Image} through
- - * {@link #load(Image)}.
- - *
- - * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
- - * concern, but if modification is necessary, please make a copy.
- - *
- - * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A"
- - * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} of
- - * this {@link TensorBuffer}.
- - * @throws IllegalStateException if the {@link TensorImage} never loads data
- - */
- - public Image getMediaImage() {
- - if (container == null) {
- - throw new IllegalStateException("No image has been loaded yet.");
- + /**
- + * Loads a {@link TensorBuffer} containing pixel values with the specific {@link
- + * ColorSpaceType}.
- + *
- + * <p>Only supports {@link ColorSpaceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link
- + * #load(TensorBuffer, ImageProperties)} for other color space types.
- + *
- + * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
- + * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- + * #getBuffer}.
- + *
- + * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or
- + * (1, h, w, 3) for RGB images, and either (h, w) or (1, h, w) for GRAYSCALE images
- + * @throws IllegalArgumentException if the shape of buffer does not match the color space type,
- + * or
- + * if the color space type is not supported
- + */
- + public void load(TensorBuffer buffer, ColorSpaceType colorSpaceType) {
- + checkArgument(
- + colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE,
- + "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
- + + " `load(TensorBuffer, ImageProperties)` for other color space types.");
- +
- + container = TensorBufferContainer.create(buffer, colorSpaceType);
- }
-
- - return container.getMediaImage();
- - }
- -
- - /**
- - * Gets the data type of this {@link TensorImage}.
- - *
- - * @return a data type. Currently only {@link DataType#UINT8} and {@link DataType#FLOAT32} are
- - * supported.
- - */
- - public DataType getDataType() {
- - return dataType;
- - }
- -
- - /**
- - * Gets the color space type of this {@link TensorImage}.
- - *
- - * @throws IllegalStateException if the {@link TensorImage} never loads data
- - */
- - public ColorSpaceType getColorSpaceType() {
- - if (container == null) {
- - throw new IllegalStateException("No image has been loaded yet.");
- + /**
- + * Loads a {@link TensorBuffer} containing pixel values with the specific {@link
- + * ImageProperties}.
- + *
- + * <p>The shape of the {@link TensorBuffer} will not be used to determine image height and
- + * width. Set image properties through {@link ImageProperties}.
- + *
- + * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
- + * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- + * #getBuffer}.
- + *
- + * @throws IllegalArgumentException if buffer size is less than the image size indicated by
- + * image
- + * height, width, and color space type in {@link ImageProperties}
- + */
- + public void load(TensorBuffer buffer, ImageProperties imageProperties) {
- + container = TensorBufferContainer.create(buffer, imageProperties);
- }
-
- - return container.getColorSpaceType();
- - }
- -
- - /**
- - * Gets the image width.
- - *
- - * @throws IllegalStateException if the {@link TensorImage} never loads data
- - * @throws IllegalArgumentException if the underlying data is corrupted
- - */
- - public int getWidth() {
- - if (container == null) {
- - throw new IllegalStateException("No image has been loaded yet.");
- + /**
- + * Loads a {@link ByteBuffer} containing pixel values with the specific {@link ImageProperties}.
- + *
- + * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
- + * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- + * #getBuffer}.
- + *
- + * @throws IllegalArgumentException if buffer size is less than the image size indicated by
- + * image
- + * height, width, and color space type in {@link ImageProperties}
- + */
- + public void load(ByteBuffer buffer, ImageProperties imageProperties) {
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- + tensorBuffer.loadBuffer(buffer, new int[] {buffer.limit()});
- + container = TensorBufferContainer.create(tensorBuffer, imageProperties);
- }
-
- - return container.getWidth();
- - }
- -
- - /**
- - * Gets the image height.
- - *
- - * @throws IllegalStateException if the {@link TensorImage} never loads data
- - * @throws IllegalArgumentException if the underlying data is corrupted
- - */
- - public int getHeight() {
- - if (container == null) {
- - throw new IllegalStateException("No image has been loaded yet.");
- + /**
- + * Loads an {@link android.media.Image} object into this {@link TensorImage}.
- + *
- + * <p>The main usage of this method is to load an {@link android.media.Image} object as model
- + * input to the <a href="TFLite Task
- + * Library">https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview</a>.
- + * {@link TensorImage} backed by {@link android.media.Image} is not supported by {@link
- + * ImageProcessor}.
- + *
- + * <p>* @throws IllegalArgumentException if the {@link android.graphics.ImageFormat} of {@code
- + * image} is not YUV_420_888
- + */
- + public void load(Image image) {
- + container = MediaImageContainer.create(image);
- }
-
- - return container.getHeight();
- - }
- + /**
- + * Returns a {@link android.graphics.Bitmap} representation of this {@link TensorImage}.
- + *
- + * <p>Numeric casting and clamping will be applied if the stored data is not uint8.
- + *
- + * <p>Note that, the reliable way to get pixels from an {@code ALPHA_8} Bitmap is to use {@code
- + * copyPixelsToBuffer}. Bitmap methods such as, `setPixels()` and `getPixels` do not work.
- + *
- + * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for
- + * performance concern, but if modification is necessary, please make a copy.
- + *
- + * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A"
- + * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType}
- + * of this {@link TensorBuffer}.
- + * @throws IllegalStateException if the {@link TensorImage} never loads data
- + */
- + public Bitmap getBitmap() {
- + if (container == null) {
- + throw new IllegalStateException("No image has been loaded yet.");
- + }
- +
- + return container.getBitmap();
- + }
- +
- + /**
- + * Returns a {@link ByteBuffer} representation of this {@link TensorImage} with the expected
- + * data type.
- + *
- + * <p>Numeric casting and clamping will be applied if the stored data is different from the data
- + * type of the {@link TensorImage}.
- + *
- + * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for
- + * performance concern, but if modification is necessary, please make a copy.
- + *
- + * <p>It's essentially a short cut for {@code getTensorBuffer().getBuffer()}.
- + *
- + * @return a reference to a {@link ByteBuffer} which holds the image data
- + * @throws IllegalStateException if the {@link TensorImage} never loads data
- + */
- + public ByteBuffer getBuffer() {
- + return getTensorBuffer().getBuffer();
- + }
- +
- + /**
- + * Returns a {@link TensorBuffer} representation of this {@link TensorImage} with the expected
- + * data type.
- + *
- + * <p>Numeric casting and clamping will be applied if the stored data is different from the data
- + * type of the {@link TensorImage}.
- + *
- + * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for
- + * performance concern, but if modification is necessary, please make a copy.
- + *
- + * @return a reference to a {@link TensorBuffer} which holds the image data
- + * @throws IllegalStateException if the {@link TensorImage} never loads data
- + */
- + public TensorBuffer getTensorBuffer() {
- + if (container == null) {
- + throw new IllegalStateException("No image has been loaded yet.");
- + }
- +
- + return container.getTensorBuffer(dataType);
- + }
- +
- + /**
- + * Returns an {@link android.media.Image} representation of this {@link TensorImage}.
- + *
- + * <p>This method only works when the {@link TensorImage} is backed by an {@link
- + * android.media.Image}, meaning you need to first load an {@link android.media.Image} through
- + * {@link #load(Image)}.
- + *
- + * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for
- + * performance concern, but if modification is necessary, please make a copy.
- + *
- + * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A"
- + * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType}
- + * of this {@link TensorBuffer}.
- + * @throws IllegalStateException if the {@link TensorImage} never loads data
- + */
- + public Image getMediaImage() {
- + if (container == null) {
- + throw new IllegalStateException("No image has been loaded yet.");
- + }
- +
- + return container.getMediaImage();
- + }
- +
- + /**
- + * Gets the data type of this {@link TensorImage}.
- + *
- + * @return a data type. Currently only {@link DataType#UINT8} and {@link DataType#FLOAT32} are
- + * supported.
- + */
- + public DataType getDataType() {
- + return dataType;
- + }
- +
- + /**
- + * Gets the color space type of this {@link TensorImage}.
- + *
- + * @throws IllegalStateException if the {@link TensorImage} never loads data
- + */
- + public ColorSpaceType getColorSpaceType() {
- + if (container == null) {
- + throw new IllegalStateException("No image has been loaded yet.");
- + }
- +
- + return container.getColorSpaceType();
- + }
- +
- + /**
- + * Gets the image width.
- + *
- + * @throws IllegalStateException if the {@link TensorImage} never loads data
- + * @throws IllegalArgumentException if the underlying data is corrupted
- + */
- + public int getWidth() {
- + if (container == null) {
- + throw new IllegalStateException("No image has been loaded yet.");
- + }
- +
- + return container.getWidth();
- + }
- +
- + /**
- + * Gets the image height.
- + *
- + * @throws IllegalStateException if the {@link TensorImage} never loads data
- + * @throws IllegalArgumentException if the underlying data is corrupted
- + */
- + public int getHeight() {
- + if (container == null) {
- + throw new IllegalStateException("No image has been loaded yet.");
- + }
- +
- + return container.getHeight();
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java
- index 06391de9cc3e0..adccf23dc97f0 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java
- @@ -19,6 +19,7 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
-
- import android.graphics.Bitmap;
- import android.graphics.PointF;
- +
- import org.checkerframework.checker.nullness.qual.NonNull;
- import org.tensorflow.lite.support.image.ColorSpaceType;
- import org.tensorflow.lite.support.image.ImageOperator;
- @@ -32,64 +33,60 @@ import org.tensorflow.lite.support.image.TensorImage;
- * @see ResizeWithCropOrPadOp for resizing without content distortion.
- */
- public class ResizeOp implements ImageOperator {
- + /** Algorithms for resizing. */
- + public enum ResizeMethod { BILINEAR, NEAREST_NEIGHBOR }
-
- - /** Algorithms for resizing. */
- - public enum ResizeMethod {
- - BILINEAR,
- - NEAREST_NEIGHBOR
- - }
- -
- - private final int targetHeight;
- - private final int targetWidth;
- - private final boolean useBilinear;
- + private final int targetHeight;
- + private final int targetWidth;
- + private final boolean useBilinear;
-
- - /**
- - * Creates a ResizeOp which can resize images to specified size in specified method.
- - *
- - * @param targetHeight The expected height of resized image.
- - * @param targetWidth The expected width of resized image.
- - * @param resizeMethod The algorithm to use for resizing. Options: {@link ResizeMethod}
- - */
- - public ResizeOp(int targetHeight, int targetWidth, ResizeMethod resizeMethod) {
- - this.targetHeight = targetHeight;
- - this.targetWidth = targetWidth;
- - useBilinear = (resizeMethod == ResizeMethod.BILINEAR);
- - }
- + /**
- + * Creates a ResizeOp which can resize images to specified size in specified method.
- + *
- + * @param targetHeight The expected height of resized image.
- + * @param targetWidth The expected width of resized image.
- + * @param resizeMethod The algorithm to use for resizing. Options: {@link ResizeMethod}
- + */
- + public ResizeOp(int targetHeight, int targetWidth, ResizeMethod resizeMethod) {
- + this.targetHeight = targetHeight;
- + this.targetWidth = targetWidth;
- + useBilinear = (resizeMethod == ResizeMethod.BILINEAR);
- + }
-
- - /**
- - * Applies the defined resizing on given image and returns the result.
- - *
- - * <p>Note: the content of input {@code image} will change, and {@code image} is the same instance
- - * with the output.
- - *
- - * @param image input image.
- - * @return output image.
- - */
- - @Override
- - @NonNull
- - public TensorImage apply(@NonNull TensorImage image) {
- - checkArgument(
- - image.getColorSpaceType() == ColorSpaceType.RGB,
- - "Only RGB images are supported in ResizeOp, but not " + image.getColorSpaceType().name());
- - Bitmap scaled =
- - Bitmap.createScaledBitmap(image.getBitmap(), targetWidth, targetHeight, useBilinear);
- - image.load(scaled);
- - return image;
- - }
- + /**
- + * Applies the defined resizing on given image and returns the result.
- + *
- + * <p>Note: the content of input {@code image} will change, and {@code image} is the same
- + * instance with the output.
- + *
- + * @param image input image.
- + * @return output image.
- + */
- + @Override
- + @NonNull
- + public TensorImage apply(@NonNull TensorImage image) {
- + checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB,
- + "Only RGB images are supported in ResizeOp, but not "
- + + image.getColorSpaceType().name());
- + Bitmap scaled = Bitmap.createScaledBitmap(
- + image.getBitmap(), targetWidth, targetHeight, useBilinear);
- + image.load(scaled);
- + return image;
- + }
-
- - @Override
- - public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
- - return targetHeight;
- - }
- + @Override
- + public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
- + return targetHeight;
- + }
-
- - @Override
- - public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
- - return targetWidth;
- - }
- + @Override
- + public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
- + return targetWidth;
- + }
-
- - @Override
- - public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- - return new PointF(
- - point.x * inputImageWidth / targetWidth, point.y * inputImageHeight / targetHeight);
- - }
- + @Override
- + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- + return new PointF(
- + point.x * inputImageWidth / targetWidth, point.y * inputImageHeight / targetHeight);
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java
- index 66491090ac9c0..e5de5bbcf50d9 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java
- @@ -22,6 +22,7 @@ import android.graphics.Bitmap.Config;
- import android.graphics.Canvas;
- import android.graphics.PointF;
- import android.graphics.Rect;
- +
- import org.checkerframework.checker.nullness.qual.NonNull;
- import org.tensorflow.lite.support.image.ColorSpaceType;
- import org.tensorflow.lite.support.image.ImageOperator;
- @@ -37,96 +38,95 @@ import org.tensorflow.lite.support.image.TensorImage;
- * @see ResizeOp for reszing images while stretching / compressing the content.
- */
- public class ResizeWithCropOrPadOp implements ImageOperator {
- - private final int targetHeight;
- - private final int targetWidth;
- - private final Bitmap output;
- -
- - /**
- - * Creates a ResizeWithCropOrPadOp which could crop/pad images to specified size. It adopts
- - * center-crop and zero-padding.
- - *
- - * @param targetHeight The expected height of cropped/padded image.
- - * @param targetWidth The expected width of cropped/padded image.
- - */
- - public ResizeWithCropOrPadOp(int targetHeight, int targetWidth) {
- - this.targetHeight = targetHeight;
- - this.targetWidth = targetWidth;
- - output = Bitmap.createBitmap(this.targetWidth, this.targetHeight, Config.ARGB_8888);
- - }
- + private final int targetHeight;
- + private final int targetWidth;
- + private final Bitmap output;
-
- - /**
- - * Applies the defined resizing with cropping or/and padding on given image and returns the
- - * result.
- - *
- - * <p>Note: the content of input {@code image} will change, and {@code image} is the same instance
- - * with the output.
- - *
- - * @param image input image.
- - * @return output image.
- - */
- - @Override
- - @NonNull
- - public TensorImage apply(@NonNull TensorImage image) {
- - checkArgument(
- - image.getColorSpaceType() == ColorSpaceType.RGB,
- - "Only RGB images are supported in ResizeWithCropOrPadOp, but not "
- - + image.getColorSpaceType().name());
- - Bitmap input = image.getBitmap();
- - int srcL;
- - int srcR;
- - int srcT;
- - int srcB;
- - int dstL;
- - int dstR;
- - int dstT;
- - int dstB;
- - int w = input.getWidth();
- - int h = input.getHeight();
- - if (targetWidth > w) { // padding
- - srcL = 0;
- - srcR = w;
- - dstL = (targetWidth - w) / 2;
- - dstR = dstL + w;
- - } else { // cropping
- - dstL = 0;
- - dstR = targetWidth;
- - srcL = (w - targetWidth) / 2;
- - srcR = srcL + targetWidth;
- + /**
- + * Creates a ResizeWithCropOrPadOp which could crop/pad images to specified size. It adopts
- + * center-crop and zero-padding.
- + *
- + * @param targetHeight The expected height of cropped/padded image.
- + * @param targetWidth The expected width of cropped/padded image.
- + */
- + public ResizeWithCropOrPadOp(int targetHeight, int targetWidth) {
- + this.targetHeight = targetHeight;
- + this.targetWidth = targetWidth;
- + output = Bitmap.createBitmap(this.targetWidth, this.targetHeight, Config.ARGB_8888);
- }
- - if (targetHeight > h) { // padding
- - srcT = 0;
- - srcB = h;
- - dstT = (targetHeight - h) / 2;
- - dstB = dstT + h;
- - } else { // cropping
- - dstT = 0;
- - dstB = targetHeight;
- - srcT = (h - targetHeight) / 2;
- - srcB = srcT + targetHeight;
- +
- + /**
- + * Applies the defined resizing with cropping or/and padding on given image and returns the
- + * result.
- + *
- + * <p>Note: the content of input {@code image} will change, and {@code image} is the same
- + * instance with the output.
- + *
- + * @param image input image.
- + * @return output image.
- + */
- + @Override
- + @NonNull
- + public TensorImage apply(@NonNull TensorImage image) {
- + checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB,
- + "Only RGB images are supported in ResizeWithCropOrPadOp, but not "
- + + image.getColorSpaceType().name());
- + Bitmap input = image.getBitmap();
- + int srcL;
- + int srcR;
- + int srcT;
- + int srcB;
- + int dstL;
- + int dstR;
- + int dstT;
- + int dstB;
- + int w = input.getWidth();
- + int h = input.getHeight();
- + if (targetWidth > w) { // padding
- + srcL = 0;
- + srcR = w;
- + dstL = (targetWidth - w) / 2;
- + dstR = dstL + w;
- + } else { // cropping
- + dstL = 0;
- + dstR = targetWidth;
- + srcL = (w - targetWidth) / 2;
- + srcR = srcL + targetWidth;
- + }
- + if (targetHeight > h) { // padding
- + srcT = 0;
- + srcB = h;
- + dstT = (targetHeight - h) / 2;
- + dstB = dstT + h;
- + } else { // cropping
- + dstT = 0;
- + dstB = targetHeight;
- + srcT = (h - targetHeight) / 2;
- + srcB = srcT + targetHeight;
- + }
- + Rect src = new Rect(srcL, srcT, srcR, srcB);
- + Rect dst = new Rect(dstL, dstT, dstR, dstB);
- + new Canvas(output).drawBitmap(input, src, dst, null);
- + image.load(output);
- + return image;
- }
- - Rect src = new Rect(srcL, srcT, srcR, srcB);
- - Rect dst = new Rect(dstL, dstT, dstR, dstB);
- - new Canvas(output).drawBitmap(input, src, dst, null);
- - image.load(output);
- - return image;
- - }
-
- - @Override
- - public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
- - return targetHeight;
- - }
- + @Override
- + public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
- + return targetHeight;
- + }
-
- - @Override
- - public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
- - return targetWidth;
- - }
- + @Override
- + public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
- + return targetWidth;
- + }
-
- - @Override
- - public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- - return transformImpl(point, targetHeight, targetWidth, inputImageHeight, inputImageWidth);
- - }
- + @Override
- + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- + return transformImpl(point, targetHeight, targetWidth, inputImageHeight, inputImageWidth);
- + }
-
- - private static PointF transformImpl(PointF point, int srcH, int srcW, int dstH, int dstW) {
- - return new PointF(point.x + (dstW - srcW) / 2, point.y + (dstH - srcH) / 2);
- - }
- + private static PointF transformImpl(PointF point, int srcH, int srcW, int dstH, int dstW) {
- + return new PointF(point.x + (dstW - srcW) / 2, point.y + (dstH - srcH) / 2);
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java
- index 849b4bc9ef3db..86413c90c69ca 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java
- @@ -20,6 +20,7 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
- import android.graphics.Bitmap;
- import android.graphics.Matrix;
- import android.graphics.PointF;
- +
- import org.checkerframework.checker.nullness.qual.NonNull;
- import org.tensorflow.lite.support.image.ColorSpaceType;
- import org.tensorflow.lite.support.image.ImageOperator;
- @@ -27,83 +28,83 @@ import org.tensorflow.lite.support.image.TensorImage;
-
- /** Rotates image counter-clockwise. */
- public class Rot90Op implements ImageOperator {
- + private final int numRotation;
-
- - private final int numRotation;
- -
- - /** Creates a Rot90 Op which will rotate image by 90 degree counter-clockwise. */
- - public Rot90Op() {
- - this(1);
- - }
- + /** Creates a Rot90 Op which will rotate image by 90 degree counter-clockwise. */
- + public Rot90Op() {
- + this(1);
- + }
-
- - /**
- - * Creates a Rot90 Op which will rotate image by 90 degree for {@code k} times counter-clockwise.
- - *
- - * @param k The number of times the image is rotated by 90 degrees. If it's positive, the image
- - * will be rotated counter-clockwise. If it's negative, the op will rotate image clockwise.
- - */
- - public Rot90Op(int k) {
- - numRotation = k % 4;
- - }
- + /**
- + * Creates a Rot90 Op which will rotate image by 90 degree for {@code k} times
- + * counter-clockwise.
- + *
- + * @param k The number of times the image is rotated by 90 degrees. If it's positive, the image
- + * will be rotated counter-clockwise. If it's negative, the op will rotate image clockwise.
- + */
- + public Rot90Op(int k) {
- + numRotation = k % 4;
- + }
-
- - /**
- - * Applies the defined rotation on given image and returns the result.
- - *
- - * <p>Note: the content of input {@code image} will change, and {@code image} is the same instance
- - * with the output.
- - *
- - * @param image input image.
- - * @return output image.
- - */
- - @NonNull
- - @Override
- - public TensorImage apply(@NonNull TensorImage image) {
- - checkArgument(
- - image.getColorSpaceType() == ColorSpaceType.RGB,
- - "Only RGB images are supported in Rot90Op, but not " + image.getColorSpaceType().name());
- - Bitmap input = image.getBitmap();
- - if (numRotation == 0) {
- - return image;
- + /**
- + * Applies the defined rotation on given image and returns the result.
- + *
- + * <p>Note: the content of input {@code image} will change, and {@code image} is the same
- + * instance with the output.
- + *
- + * @param image input image.
- + * @return output image.
- + */
- + @NonNull
- + @Override
- + public TensorImage apply(@NonNull TensorImage image) {
- + checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB,
- + "Only RGB images are supported in Rot90Op, but not "
- + + image.getColorSpaceType().name());
- + Bitmap input = image.getBitmap();
- + if (numRotation == 0) {
- + return image;
- + }
- + int w = input.getWidth();
- + int h = input.getHeight();
- + Matrix matrix = new Matrix();
- + matrix.postTranslate(w * 0.5f, h * 0.5f);
- + matrix.postRotate(-90 * numRotation);
- + int newW = (numRotation % 2 == 0) ? w : h;
- + int newH = (numRotation % 2 == 0) ? h : w;
- + matrix.postTranslate(newW * 0.5f, newH * 0.5f);
- + Bitmap output = Bitmap.createBitmap(input, 0, 0, w, h, matrix, false);
- + image.load(output);
- + return image;
- }
- - int w = input.getWidth();
- - int h = input.getHeight();
- - Matrix matrix = new Matrix();
- - matrix.postTranslate(w * 0.5f, h * 0.5f);
- - matrix.postRotate(-90 * numRotation);
- - int newW = (numRotation % 2 == 0) ? w : h;
- - int newH = (numRotation % 2 == 0) ? h : w;
- - matrix.postTranslate(newW * 0.5f, newH * 0.5f);
- - Bitmap output = Bitmap.createBitmap(input, 0, 0, w, h, matrix, false);
- - image.load(output);
- - return image;
- - }
-
- - @Override
- - public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
- - return (numRotation % 2 == 0) ? inputImageHeight : inputImageWidth;
- - }
- + @Override
- + public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
- + return (numRotation % 2 == 0) ? inputImageHeight : inputImageWidth;
- + }
-
- - @Override
- - public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
- - return (numRotation % 2 == 0) ? inputImageWidth : inputImageHeight;
- - }
- + @Override
- + public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
- + return (numRotation % 2 == 0) ? inputImageWidth : inputImageHeight;
- + }
-
- - @Override
- - public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- - int inverseNumRotation = (4 - numRotation) % 4;
- - int height = getOutputImageHeight(inputImageHeight, inputImageWidth);
- - int width = getOutputImageWidth(inputImageHeight, inputImageWidth);
- - return transformImpl(point, height, width, inverseNumRotation);
- - }
- + @Override
- + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- + int inverseNumRotation = (4 - numRotation) % 4;
- + int height = getOutputImageHeight(inputImageHeight, inputImageWidth);
- + int width = getOutputImageWidth(inputImageHeight, inputImageWidth);
- + return transformImpl(point, height, width, inverseNumRotation);
- + }
-
- - private static PointF transformImpl(PointF point, int height, int width, int numRotation) {
- - if (numRotation == 0) {
- - return point;
- - } else if (numRotation == 1) {
- - return new PointF(point.y, width - point.x);
- - } else if (numRotation == 2) {
- - return new PointF(width - point.x, height - point.y);
- - } else { // numRotation == 3
- - return new PointF(height - point.y, point.x);
- + private static PointF transformImpl(PointF point, int height, int width, int numRotation) {
- + if (numRotation == 0) {
- + return point;
- + } else if (numRotation == 1) {
- + return new PointF(point.y, width - point.x);
- + } else if (numRotation == 2) {
- + return new PointF(width - point.x, height - point.y);
- + } else { // numRotation == 3
- + return new PointF(height - point.y, point.x);
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java
- index 5d10ac890e57b..feb2b3b7b0762 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java
- @@ -16,6 +16,7 @@ limitations under the License.
- package org.tensorflow.lite.support.image.ops;
-
- import android.graphics.PointF;
- +
- import org.checkerframework.checker.nullness.qual.NonNull;
- import org.tensorflow.lite.support.common.TensorOperator;
- import org.tensorflow.lite.support.common.internal.SupportPreconditions;
- @@ -31,48 +32,47 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- * @see org.tensorflow.lite.support.image.TensorImage
- */
- public class TensorOperatorWrapper implements ImageOperator {
- + private final TensorOperator tensorOp;
-
- - private final TensorOperator tensorOp;
- -
- - /**
- - * Wraps a {@link TensorOperator} object as an {@link ImageOperator}, so that the {@link
- - * TensorOperator} could handle {@link TensorImage} objects by handling its underlying {@link
- - * org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
- - *
- - * <p>Requirement: The {@code op} should not change coordinate system when applied on an image.
- - *
- - * @param op The created operator.
- - */
- - public TensorOperatorWrapper(TensorOperator op) {
- - tensorOp = op;
- - }
- + /**
- + * Wraps a {@link TensorOperator} object as an {@link ImageOperator}, so that the {@link
- + * TensorOperator} could handle {@link TensorImage} objects by handling its underlying {@link
- + * org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
- + *
- + * <p>Requirement: The {@code op} should not change coordinate system when applied on an image.
- + *
- + * @param op The created operator.
- + */
- + public TensorOperatorWrapper(TensorOperator op) {
- + tensorOp = op;
- + }
-
- - @Override
- - @NonNull
- - public TensorImage apply(@NonNull TensorImage image) {
- - SupportPreconditions.checkNotNull(image, "Op cannot apply on null image.");
- - TensorBuffer resBuffer = tensorOp.apply(image.getTensorBuffer());
- - // Some ops may change the data type of the underlying TensorBuffer, such as CastOp. Therefore,
- - // need to create a new TensorImage with the correct data type.
- - // However the underlying ops should not touch the color type.
- - ColorSpaceType colorSpaceType = image.getColorSpaceType();
- - TensorImage resImage = new TensorImage(resBuffer.getDataType());
- - resImage.load(resBuffer, colorSpaceType);
- - return resImage;
- - }
- + @Override
- + @NonNull
- + public TensorImage apply(@NonNull TensorImage image) {
- + SupportPreconditions.checkNotNull(image, "Op cannot apply on null image.");
- + TensorBuffer resBuffer = tensorOp.apply(image.getTensorBuffer());
- + // Some ops may change the data type of the underlying TensorBuffer, such as CastOp.
- + // Therefore, need to create a new TensorImage with the correct data type. However the
- + // underlying ops should not touch the color type.
- + ColorSpaceType colorSpaceType = image.getColorSpaceType();
- + TensorImage resImage = new TensorImage(resBuffer.getDataType());
- + resImage.load(resBuffer, colorSpaceType);
- + return resImage;
- + }
-
- - @Override
- - public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
- - return inputImageHeight;
- - }
- + @Override
- + public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
- + return inputImageHeight;
- + }
-
- - @Override
- - public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
- - return inputImageWidth;
- - }
- + @Override
- + public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
- + return inputImageWidth;
- + }
-
- - @Override
- - public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- - return point;
- - }
- + @Override
- + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- + return point;
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java
- index bd3c10b254ac5..1a6f905b1bffd 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java
- @@ -23,6 +23,7 @@ import android.graphics.ColorFilter;
- import android.graphics.ColorMatrixColorFilter;
- import android.graphics.Paint;
- import android.graphics.PointF;
- +
- import org.tensorflow.lite.support.image.ColorSpaceType;
- import org.tensorflow.lite.support.image.ImageOperator;
- import org.tensorflow.lite.support.image.TensorImage;
- @@ -41,77 +42,73 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- * https://docs.opencv.org/master/de/d25/imgproc_color_conversions.html#color_convert_rgb_gray
- */
- public class TransformToGrayscaleOp implements ImageOperator {
- + // A matrix is created that will be applied later to canvas to generate grayscale image
- + // The luminance of each pixel is calculated as the weighted sum of the 3 RGB values
- + // Y = 0.299R + 0.587G + 0.114B
- + private static final float[] BITMAP_RGBA_GRAYSCALE_TRANSFORMATION =
- + new float[] {0.299F, 0.587F, 0.114F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F,
- + 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F};
-
- - // A matrix is created that will be applied later to canvas to generate grayscale image
- - // The luminance of each pixel is calculated as the weighted sum of the 3 RGB values
- - // Y = 0.299R + 0.587G + 0.114B
- - private static final float[] BITMAP_RGBA_GRAYSCALE_TRANSFORMATION =
- - new float[] {
- - 0.299F, 0.587F, 0.114F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F,
- - 0.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F
- - };
- -
- - /** Creates a TransformToGrayscaleOp. */
- - public TransformToGrayscaleOp() {}
- + /** Creates a TransformToGrayscaleOp. */
- + public TransformToGrayscaleOp() {}
-
- - /**
- - * Applies the transformation to grayscale and returns a {@link TensorImage}.
- - *
- - * <p>If the input image is already {@link
- - * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}, this op will be a no-op.
- - *
- - * @throws IllegalArgumentException if the {@code image} is not {@link
- - * org.tensorflow.lite.support.image.ColorSpaceType#RGB} or {@link
- - * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}.
- - */
- - @Override
- - public TensorImage apply(TensorImage image) {
- - if (image.getColorSpaceType() == ColorSpaceType.GRAYSCALE) {
- - return image;
- - } else {
- - checkArgument(
- - image.getColorSpaceType() == ColorSpaceType.RGB,
- - "Only RGB images are supported in TransformToGrayscaleOp, but not "
- - + image.getColorSpaceType().name());
- - }
- - int h = image.getHeight();
- - int w = image.getWidth();
- - Bitmap bmpGrayscale = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
- - Canvas canvas = new Canvas(bmpGrayscale);
- - Paint paint = new Paint();
- - ColorMatrixColorFilter colorMatrixFilter =
- - new ColorMatrixColorFilter(BITMAP_RGBA_GRAYSCALE_TRANSFORMATION);
- - paint.setColorFilter((ColorFilter) colorMatrixFilter);
- - canvas.drawBitmap(image.getBitmap(), 0.0F, 0.0F, paint);
- + /**
- + * Applies the transformation to grayscale and returns a {@link TensorImage}.
- + *
- + * <p>If the input image is already {@link
- + * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}, this op will be a no-op.
- + *
- + * @throws IllegalArgumentException if the {@code image} is not {@link
- + * org.tensorflow.lite.support.image.ColorSpaceType#RGB} or {@link
- + * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}.
- + */
- + @Override
- + public TensorImage apply(TensorImage image) {
- + if (image.getColorSpaceType() == ColorSpaceType.GRAYSCALE) {
- + return image;
- + } else {
- + checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB,
- + "Only RGB images are supported in TransformToGrayscaleOp, but not "
- + + image.getColorSpaceType().name());
- + }
- + int h = image.getHeight();
- + int w = image.getWidth();
- + Bitmap bmpGrayscale = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
- + Canvas canvas = new Canvas(bmpGrayscale);
- + Paint paint = new Paint();
- + ColorMatrixColorFilter colorMatrixFilter =
- + new ColorMatrixColorFilter(BITMAP_RGBA_GRAYSCALE_TRANSFORMATION);
- + paint.setColorFilter((ColorFilter) colorMatrixFilter);
- + canvas.drawBitmap(image.getBitmap(), 0.0F, 0.0F, paint);
-
- - // Get the pixels from the generated grayscale image
- - int[] intValues = new int[w * h];
- - bmpGrayscale.getPixels(intValues, 0, w, 0, 0, w, h);
- - // Shape with one channel
- - int[] shape = new int[] {1, h, w, 1};
- + // Get the pixels from the generated grayscale image
- + int[] intValues = new int[w * h];
- + bmpGrayscale.getPixels(intValues, 0, w, 0, 0, w, h);
- + // Shape with one channel
- + int[] shape = new int[] {1, h, w, 1};
-
- - // Get R channel from ARGB color
- - for (int i = 0; i < intValues.length; i++) {
- - intValues[i] = ((intValues[i] >> 16) & 0xff);
- + // Get R channel from ARGB color
- + for (int i = 0; i < intValues.length; i++) {
- + intValues[i] = ((intValues[i] >> 16) & 0xff);
- + }
- + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, image.getDataType());
- + buffer.loadArray(intValues, shape);
- + image.load(buffer, ColorSpaceType.GRAYSCALE);
- + return image;
- }
- - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, image.getDataType());
- - buffer.loadArray(intValues, shape);
- - image.load(buffer, ColorSpaceType.GRAYSCALE);
- - return image;
- - }
-
- - @Override
- - public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
- - return inputImageHeight;
- - }
- + @Override
- + public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
- + return inputImageHeight;
- + }
-
- - @Override
- - public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
- - return inputImageWidth;
- - }
- + @Override
- + public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
- + return inputImageWidth;
- + }
-
- - @Override
- - public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- - return point;
- - }
- + @Override
- + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- + return point;
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java
- index 8135ddcc28619..af56b70a77cf3 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java
- @@ -15,9 +15,10 @@ limitations under the License.
-
- package org.tensorflow.lite.support.label;
-
- -import java.util.Objects;
- import org.tensorflow.lite.annotations.UsedByReflection;
-
- +import java.util.Objects;
- +
- /**
- * Category is a util class, contains a label, its display name, a float value as score, and the
- * index of the label in the corresponding label file. Typically it's used as result of
- @@ -25,102 +26,97 @@ import org.tensorflow.lite.annotations.UsedByReflection;
- */
- @UsedByReflection("TFLiteSupport/Task")
- public final class Category {
- - private static final int DEFAULT_INDEX = -1;
- - private static final float TOLERANCE = 1e-6f;
- - private final int index;
- - private final String label;
- - private final String displayName;
- - private final float score;
- -
- - /**
- - * Constructs a {@link Category} object.
- - *
- - * @param label the label of this category object
- - * @param displayName the display name of the label, which may be translated for different
- - * locales. For exmaple, a label, "apple", may be translated into Spanish for display purpose,
- - * so that the displayName is "manzana".
- - * @param score the probability score of this label category
- - * @param index the index of the label in the corresponding label file
- - */
- - @UsedByReflection("TFLiteSupport/Task")
- - public static Category create(String label, String displayName, float score, int index) {
- - return new Category(label, displayName, score, index);
- - }
- -
- - /** Constructs a {@link Category} object with the default index (-1). */
- - @UsedByReflection("TFLiteSupport/Task")
- - public static Category create(String label, String displayName, float score) {
- - return new Category(label, displayName, score, DEFAULT_INDEX);
- - }
- -
- - /** Constructs a {@link Category} object with an empty displayName and the default index (-1). */
- - @UsedByReflection("TFLiteSupport/Task")
- - public Category(String label, float score) {
- - this(label, /*displayName=*/ "", score, DEFAULT_INDEX);
- - }
- -
- - private Category(String label, String displayName, float score, int index) {
- - this.label = label;
- - this.displayName = displayName;
- - this.score = score;
- - this.index = index;
- - }
- -
- - /** Gets the reference of category's label. */
- - public String getLabel() {
- - return label;
- - }
- -
- - /**
- - * Gets the reference of category's displayName, a name in locale of the label.
- - *
- - * <p>The display name can be an empty string if this {@link Category} object is constructed
- - * without displayName, such as when using {@link #Category(String label, float score)}.
- - */
- - public String getDisplayName() {
- - return displayName;
- - }
- -
- - /** Gets the score of the category. */
- - public float getScore() {
- - return score;
- - }
- -
- - /**
- - * Gets the index of the category. The index value might be -1, which means it has not been set up
- - * properly and is invalid.
- - */
- - public int getIndex() {
- - return index;
- - }
- -
- - @Override
- - public boolean equals(Object o) {
- - if (o instanceof Category) {
- - Category other = (Category) o;
- - return (other.getLabel().equals(this.label)
- - && other.getDisplayName().equals(this.displayName)
- - && Math.abs(other.getScore() - this.score) < TOLERANCE
- - && other.getIndex() == this.index);
- + private static final int DEFAULT_INDEX = -1;
- + private static final float TOLERANCE = 1e-6f;
- + private final int index;
- + private final String label;
- + private final String displayName;
- + private final float score;
- +
- + /**
- + * Constructs a {@link Category} object.
- + *
- + * @param label the label of this category object
- + * @param displayName the display name of the label, which may be translated for different
- + * locales. For exmaple, a label, "apple", may be translated into Spanish for display
- + * purpose, so that the displayName is "manzana".
- + * @param score the probability score of this label category
- + * @param index the index of the label in the corresponding label file
- + */
- + @UsedByReflection("TFLiteSupport/Task")
- + public static Category create(String label, String displayName, float score, int index) {
- + return new Category(label, displayName, score, index);
- + }
- +
- + /** Constructs a {@link Category} object with the default index (-1). */
- + @UsedByReflection("TFLiteSupport/Task")
- + public static Category create(String label, String displayName, float score) {
- + return new Category(label, displayName, score, DEFAULT_INDEX);
- + }
- +
- + /**
- + * Constructs a {@link Category} object with an empty displayName and the default index (-1).
- + */
- + @UsedByReflection("TFLiteSupport/Task")
- + public Category(String label, float score) {
- + this(label, /*displayName=*/"", score, DEFAULT_INDEX);
- + }
- +
- + private Category(String label, String displayName, float score, int index) {
- + this.label = label;
- + this.displayName = displayName;
- + this.score = score;
- + this.index = index;
- + }
- +
- + /** Gets the reference of category's label. */
- + public String getLabel() {
- + return label;
- + }
- +
- + /**
- + * Gets the reference of category's displayName, a name in locale of the label.
- + *
- + * <p>The display name can be an empty string if this {@link Category} object is constructed
- + * without displayName, such as when using {@link #Category(String label, float score)}.
- + */
- + public String getDisplayName() {
- + return displayName;
- + }
- +
- + /** Gets the score of the category. */
- + public float getScore() {
- + return score;
- + }
- +
- + /**
- + * Gets the index of the category. The index value might be -1, which means it has not been set
- + * up properly and is invalid.
- + */
- + public int getIndex() {
- + return index;
- + }
- +
- + @Override
- + public boolean equals(Object o) {
- + if (o instanceof Category) {
- + Category other = (Category) o;
- + return (other.getLabel().equals(this.label)
- + && other.getDisplayName().equals(this.displayName)
- + && Math.abs(other.getScore() - this.score) < TOLERANCE
- + && other.getIndex() == this.index);
- + }
- + return false;
- + }
- +
- + @Override
- + public int hashCode() {
- + return Objects.hash(label, displayName, score, index);
- + }
- +
- + @Override
- + public String toString() {
- + return "<Category \"" + label + "\" (displayName=" + displayName + " score=" + score
- + + " index=" + index + ")>";
- }
- - return false;
- - }
- -
- - @Override
- - public int hashCode() {
- - return Objects.hash(label, displayName, score, index);
- - }
- -
- - @Override
- - public String toString() {
- - return "<Category \""
- - + label
- - + "\" (displayName="
- - + displayName
- - + " score="
- - + score
- - + " index="
- - + index
- - + ")>";
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java
- index af21d74e25f5d..56ee89f091e03 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java
- @@ -16,49 +16,52 @@ limitations under the License.
- package org.tensorflow.lite.support.label;
-
- import android.util.Log;
- -import java.util.ArrayList;
- -import java.util.Arrays;
- -import java.util.List;
- +
- import org.checkerframework.checker.nullness.qual.NonNull;
- import org.tensorflow.lite.support.common.internal.SupportPreconditions;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- +import java.util.ArrayList;
- +import java.util.Arrays;
- +import java.util.List;
- +
- /** Label operation utils. */
- public class LabelUtil {
- - /**
- - * Maps an int value tensor to a list of string labels. It takes an array of strings as the
- - * dictionary. Example: if the given tensor is [3, 1, 0], and given labels is ["background",
- - * "apple", "banana", "cherry", "date"], the result will be ["date", "banana", "apple"].
- - *
- - * @param tensorBuffer A tensor with index values. The values should be non-negative integers, and
- - * each value {@code x} will be converted to {@code labels[x + offset]}. If the tensor is
- - * given as a float {@link TensorBuffer}, values will be cast to integers. All values that are
- - * out of bound will map to empty string.
- - * @param labels A list of strings, used as a dictionary to look up. The index of the array
- - * element will be used as the key. To get better performance, use an object that implements
- - * RandomAccess, such as {@link ArrayList}.
- - * @param offset The offset value when look up int values in the {@code labels}.
- - * @return the mapped strings. The length of the list is {@link TensorBuffer#getFlatSize}.
- - * @throws IllegalArgumentException if {@code tensorBuffer} or {@code labels} is null.
- - */
- - public static List<String> mapValueToLabels(
- - @NonNull TensorBuffer tensorBuffer, @NonNull List<String> labels, int offset) {
- - SupportPreconditions.checkNotNull(tensorBuffer, "Given tensor should not be null");
- - SupportPreconditions.checkNotNull(labels, "Given labels should not be null");
- - int[] values = tensorBuffer.getIntArray();
- - Log.d("values", Arrays.toString(values));
- - List<String> result = new ArrayList<>();
- - for (int v : values) {
- - int index = v + offset;
- - if (index < 0 || index >= labels.size()) {
- - result.add("");
- - } else {
- - result.add(labels.get(index));
- - }
- + /**
- + * Maps an int value tensor to a list of string labels. It takes an array of strings as the
- + * dictionary. Example: if the given tensor is [3, 1, 0], and given labels is ["background",
- + * "apple", "banana", "cherry", "date"], the result will be ["date", "banana", "apple"].
- + *
- + * @param tensorBuffer A tensor with index values. The values should be non-negative integers,
- + * and
- + * each value {@code x} will be converted to {@code labels[x + offset]}. If the tensor is
- + * given as a float {@link TensorBuffer}, values will be cast to integers. All values that
- + * are out of bound will map to empty string.
- + * @param labels A list of strings, used as a dictionary to look up. The index of the array
- + * element will be used as the key. To get better performance, use an object that implements
- + * RandomAccess, such as {@link ArrayList}.
- + * @param offset The offset value when look up int values in the {@code labels}.
- + * @return the mapped strings. The length of the list is {@link TensorBuffer#getFlatSize}.
- + * @throws IllegalArgumentException if {@code tensorBuffer} or {@code labels} is null.
- + */
- + public static List<String> mapValueToLabels(
- + @NonNull TensorBuffer tensorBuffer, @NonNull List<String> labels, int offset) {
- + SupportPreconditions.checkNotNull(tensorBuffer, "Given tensor should not be null");
- + SupportPreconditions.checkNotNull(labels, "Given labels should not be null");
- + int[] values = tensorBuffer.getIntArray();
- + Log.d("values", Arrays.toString(values));
- + List<String> result = new ArrayList<>();
- + for (int v : values) {
- + int index = v + offset;
- + if (index < 0 || index >= labels.size()) {
- + result.add("");
- + } else {
- + result.add(labels.get(index));
- + }
- + }
- + return result;
- }
- - return result;
- - }
-
- - // Private constructor to prevent initialization.
- - private LabelUtil() {}
- + // Private constructor to prevent initialization.
- + private LabelUtil() {}
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java
- index bdab7cf464c1b..edd683cd08126 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java
- @@ -16,16 +16,18 @@ limitations under the License.
- package org.tensorflow.lite.support.label;
-
- import android.content.Context;
- +
- +import org.checkerframework.checker.nullness.qual.NonNull;
- +import org.tensorflow.lite.DataType;
- +import org.tensorflow.lite.support.common.internal.SupportPreconditions;
- +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- +
- import java.nio.ByteBuffer;
- import java.util.ArrayList;
- import java.util.Arrays;
- import java.util.LinkedHashMap;
- import java.util.List;
- import java.util.Map;
- -import org.checkerframework.checker.nullness.qual.NonNull;
- -import org.tensorflow.lite.DataType;
- -import org.tensorflow.lite.support.common.internal.SupportPreconditions;
- -import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- /**
- * TensorLabel is an util wrapper for TensorBuffers with meaningful labels on an axis.
- @@ -56,169 +58,170 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- * a label file (plain text file whose each line is a label) in assets simply.
- */
- public class TensorLabel {
- - private final Map<Integer, List<String>> axisLabels;
- - private final TensorBuffer tensorBuffer;
- - private final int[] shape;
- -
- - /**
- - * Creates a TensorLabel object which is able to label on the axes of multi-dimensional tensors.
- - *
- - * @param axisLabels A map, whose key is axis id (starting from 0) and value is corresponding
- - * labels. Note: The size of labels should be same with the size of the tensor on that axis.
- - * @param tensorBuffer The TensorBuffer to be labeled.
- - * @throws NullPointerException if {@code axisLabels} or {@code tensorBuffer} is null, or any
- - * value in {@code axisLabels} is null.
- - * @throws IllegalArgumentException if any key in {@code axisLabels} is out of range (compared to
- - * the shape of {@code tensorBuffer}, or any value (labels) has different size with the {@code
- - * tensorBuffer} on the given dimension.
- - */
- - public TensorLabel(
- - @NonNull Map<Integer, List<String>> axisLabels, @NonNull TensorBuffer tensorBuffer) {
- - SupportPreconditions.checkNotNull(axisLabels, "Axis labels cannot be null.");
- - SupportPreconditions.checkNotNull(tensorBuffer, "Tensor Buffer cannot be null.");
- - this.axisLabels = axisLabels;
- - this.tensorBuffer = tensorBuffer;
- - this.shape = tensorBuffer.getShape();
- - for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) {
- - int axis = entry.getKey();
- - SupportPreconditions.checkArgument(
- - axis >= 0 && axis < shape.length, "Invalid axis id: " + axis);
- - SupportPreconditions.checkNotNull(entry.getValue(), "Label list is null on axis " + axis);
- - SupportPreconditions.checkArgument(
- - shape[axis] == entry.getValue().size(),
- - "Label number " + entry.getValue().size() + " mismatch the shape on axis " + axis);
- + private final Map<Integer, List<String>> axisLabels;
- + private final TensorBuffer tensorBuffer;
- + private final int[] shape;
- +
- + /**
- + * Creates a TensorLabel object which is able to label on the axes of multi-dimensional tensors.
- + *
- + * @param axisLabels A map, whose key is axis id (starting from 0) and value is corresponding
- + * labels. Note: The size of labels should be same with the size of the tensor on that axis.
- + * @param tensorBuffer The TensorBuffer to be labeled.
- + * @throws NullPointerException if {@code axisLabels} or {@code tensorBuffer} is null, or any
- + * value in {@code axisLabels} is null.
- + * @throws IllegalArgumentException if any key in {@code axisLabels} is out of range (compared
- + * to
- + * the shape of {@code tensorBuffer}, or any value (labels) has different size with the
- + * {@code tensorBuffer} on the given dimension.
- + */
- + public TensorLabel(
- + @NonNull Map<Integer, List<String>> axisLabels, @NonNull TensorBuffer tensorBuffer) {
- + SupportPreconditions.checkNotNull(axisLabels, "Axis labels cannot be null.");
- + SupportPreconditions.checkNotNull(tensorBuffer, "Tensor Buffer cannot be null.");
- + this.axisLabels = axisLabels;
- + this.tensorBuffer = tensorBuffer;
- + this.shape = tensorBuffer.getShape();
- + for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) {
- + int axis = entry.getKey();
- + SupportPreconditions.checkArgument(
- + axis >= 0 && axis < shape.length, "Invalid axis id: " + axis);
- + SupportPreconditions.checkNotNull(
- + entry.getValue(), "Label list is null on axis " + axis);
- + SupportPreconditions.checkArgument(shape[axis] == entry.getValue().size(),
- + "Label number " + entry.getValue().size() + " mismatch the shape on axis "
- + + axis);
- + }
- }
- - }
- -
- - /**
- - * Creates a TensorLabel object which is able to label on one axis of multi-dimensional tensors.
- - *
- - * <p>Note: The labels are applied on the first axis whose size is larger than 1. For example, if
- - * the shape of the tensor is [1, 10, 3], the labels will be applied on axis 1 (id starting from
- - * 0), and size of {@code axisLabels} should be 10 as well.
- - *
- - * @param axisLabels A list of labels, whose size should be same with the size of the tensor on
- - * the to-be-labeled axis.
- - * @param tensorBuffer The TensorBuffer to be labeled.
- - */
- - public TensorLabel(@NonNull List<String> axisLabels, @NonNull TensorBuffer tensorBuffer) {
- - this(makeMap(getFirstAxisWithSizeGreaterThanOne(tensorBuffer), axisLabels), tensorBuffer);
- - }
- -
- - /**
- - * Gets the map with a pair of the label and the corresponding TensorBuffer. Only allow the
- - * mapping on the first axis with size greater than 1 currently.
- - */
- - @NonNull
- - public Map<String, TensorBuffer> getMapWithTensorBuffer() {
- - int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
- -
- - Map<String, TensorBuffer> labelToTensorMap = new LinkedHashMap<>();
- - SupportPreconditions.checkArgument(
- - axisLabels.containsKey(labeledAxis),
- - "get a <String, TensorBuffer> map requires the labels are set on the first non-1 axis.");
- - List<String> labels = axisLabels.get(labeledAxis);
- -
- - DataType dataType = tensorBuffer.getDataType();
- - int typeSize = tensorBuffer.getTypeSize();
- - int flatSize = tensorBuffer.getFlatSize();
- -
- - // Gets the underlying bytes that could be used to generate the sub-array later.
- - ByteBuffer byteBuffer = tensorBuffer.getBuffer();
- - byteBuffer.rewind();
- -
- - // Note: computation below is only correct when labeledAxis is the first axis with size greater
- - // than 1.
- - int subArrayLength = flatSize / shape[labeledAxis] * typeSize;
- - int i = 0;
- - SupportPreconditions.checkNotNull(labels, "Label list should never be null");
- - for (String label : labels) {
- - // Gets the corresponding TensorBuffer.
- - byteBuffer.position(i * subArrayLength);
- - ByteBuffer subBuffer = byteBuffer.slice();
- - // ByteBuffer.slice doesn't keep order. Modify it to align with the original one.
- - subBuffer.order(byteBuffer.order()).limit(subArrayLength);
- - TensorBuffer labelBuffer = TensorBuffer.createDynamic(dataType);
- - labelBuffer.loadBuffer(subBuffer, Arrays.copyOfRange(shape, labeledAxis + 1, shape.length));
- - labelToTensorMap.put(label, labelBuffer);
- - i += 1;
- +
- + /**
- + * Creates a TensorLabel object which is able to label on one axis of multi-dimensional tensors.
- + *
- + * <p>Note: The labels are applied on the first axis whose size is larger than 1. For example,
- + * if the shape of the tensor is [1, 10, 3], the labels will be applied on axis 1 (id starting
- + * from 0), and size of {@code axisLabels} should be 10 as well.
- + *
- + * @param axisLabels A list of labels, whose size should be same with the size of the tensor on
- + * the to-be-labeled axis.
- + * @param tensorBuffer The TensorBuffer to be labeled.
- + */
- + public TensorLabel(@NonNull List<String> axisLabels, @NonNull TensorBuffer tensorBuffer) {
- + this(makeMap(getFirstAxisWithSizeGreaterThanOne(tensorBuffer), axisLabels), tensorBuffer);
- }
- - return labelToTensorMap;
- - }
- -
- - /**
- - * Gets a map that maps label to float. Only allow the mapping on the first axis with size greater
- - * than 1, and the axis should be effectively the last axis (which means every sub tensor
- - * specified by this axis should have a flat size of 1).
- - *
- - * <p>{@link TensorLabel#getCategoryList()} is an alternative API to get the result.
- - *
- - * @throws IllegalStateException if size of a sub tensor on each label is not 1.
- - */
- - @NonNull
- - public Map<String, Float> getMapWithFloatValue() {
- - int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
- - SupportPreconditions.checkState(
- - labeledAxis == shape.length - 1,
- - "get a <String, Scalar> map is only valid when the only labeled axis is the last one.");
- - List<String> labels = axisLabels.get(labeledAxis);
- - float[] data = tensorBuffer.getFloatArray();
- - SupportPreconditions.checkState(labels.size() == data.length);
- - Map<String, Float> result = new LinkedHashMap<>();
- - int i = 0;
- - for (String label : labels) {
- - result.put(label, data[i]);
- - i += 1;
- +
- + /**
- + * Gets the map with a pair of the label and the corresponding TensorBuffer. Only allow the
- + * mapping on the first axis with size greater than 1 currently.
- + */
- + @NonNull
- + public Map<String, TensorBuffer> getMapWithTensorBuffer() {
- + int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
- +
- + Map<String, TensorBuffer> labelToTensorMap = new LinkedHashMap<>();
- + SupportPreconditions.checkArgument(axisLabels.containsKey(labeledAxis),
- + "get a <String, TensorBuffer> map requires the labels are set on the first non-1 axis.");
- + List<String> labels = axisLabels.get(labeledAxis);
- +
- + DataType dataType = tensorBuffer.getDataType();
- + int typeSize = tensorBuffer.getTypeSize();
- + int flatSize = tensorBuffer.getFlatSize();
- +
- + // Gets the underlying bytes that could be used to generate the sub-array later.
- + ByteBuffer byteBuffer = tensorBuffer.getBuffer();
- + byteBuffer.rewind();
- +
- + // Note: computation below is only correct when labeledAxis is the first axis with size
- + // greater than 1.
- + int subArrayLength = flatSize / shape[labeledAxis] * typeSize;
- + int i = 0;
- + SupportPreconditions.checkNotNull(labels, "Label list should never be null");
- + for (String label : labels) {
- + // Gets the corresponding TensorBuffer.
- + byteBuffer.position(i * subArrayLength);
- + ByteBuffer subBuffer = byteBuffer.slice();
- + // ByteBuffer.slice doesn't keep order. Modify it to align with the original one.
- + subBuffer.order(byteBuffer.order()).limit(subArrayLength);
- + TensorBuffer labelBuffer = TensorBuffer.createDynamic(dataType);
- + labelBuffer.loadBuffer(
- + subBuffer, Arrays.copyOfRange(shape, labeledAxis + 1, shape.length));
- + labelToTensorMap.put(label, labelBuffer);
- + i += 1;
- + }
- + return labelToTensorMap;
- }
- - return result;
- - }
- -
- - /**
- - * Gets a list of {@link Category} from the {@link TensorLabel} object.
- - *
- - * <p>The axis of label should be effectively the last axis (which means every sub tensor
- - * specified by this axis should have a flat size of 1), so that each labelled sub tensor could be
- - * converted into a float value score. Example: A {@link TensorLabel} with shape {@code {2, 5, 3}}
- - * and axis 2 is valid. If axis is 1 or 0, it cannot be converted into a {@link Category}.
- - *
- - * <p>{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as
- - * the result.
- - *
- - * @throws IllegalStateException if size of a sub tensor on each label is not 1.
- - */
- - @NonNull
- - public List<Category> getCategoryList() {
- - int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
- - SupportPreconditions.checkState(
- - labeledAxis == shape.length - 1,
- - "get a Category list is only valid when the only labeled axis is the last one.");
- - List<String> labels = axisLabels.get(labeledAxis);
- - float[] data = tensorBuffer.getFloatArray();
- - SupportPreconditions.checkState(labels.size() == data.length);
- - List<Category> result = new ArrayList<>();
- - int i = 0;
- - for (String label : labels) {
- - result.add(new Category(label, data[i]));
- - i += 1;
- +
- + /**
- + * Gets a map that maps label to float. Only allow the mapping on the first axis with size
- + * greater than 1, and the axis should be effectively the last axis (which means every sub
- + * tensor specified by this axis should have a flat size of 1).
- + *
- + * <p>{@link TensorLabel#getCategoryList()} is an alternative API to get the result.
- + *
- + * @throws IllegalStateException if size of a sub tensor on each label is not 1.
- + */
- + @NonNull
- + public Map<String, Float> getMapWithFloatValue() {
- + int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
- + SupportPreconditions.checkState(labeledAxis == shape.length - 1,
- + "get a <String, Scalar> map is only valid when the only labeled axis is the last one.");
- + List<String> labels = axisLabels.get(labeledAxis);
- + float[] data = tensorBuffer.getFloatArray();
- + SupportPreconditions.checkState(labels.size() == data.length);
- + Map<String, Float> result = new LinkedHashMap<>();
- + int i = 0;
- + for (String label : labels) {
- + result.put(label, data[i]);
- + i += 1;
- + }
- + return result;
- }
- - return result;
- - }
- -
- - private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) {
- - int[] shape = tensorBuffer.getShape();
- - for (int i = 0; i < shape.length; i++) {
- - if (shape[i] > 1) {
- - return i;
- - }
- +
- + /**
- + * Gets a list of {@link Category} from the {@link TensorLabel} object.
- + *
- + * <p>The axis of label should be effectively the last axis (which means every sub tensor
- + * specified by this axis should have a flat size of 1), so that each labelled sub tensor could
- + * be converted into a float value score. Example: A {@link TensorLabel} with shape {@code {2,
- + * 5, 3}} and axis 2 is valid. If axis is 1 or 0, it cannot be converted into a {@link
- + * Category}.
- + *
- + * <p>{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as
- + * the result.
- + *
- + * @throws IllegalStateException if size of a sub tensor on each label is not 1.
- + */
- + @NonNull
- + public List<Category> getCategoryList() {
- + int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
- + SupportPreconditions.checkState(labeledAxis == shape.length - 1,
- + "get a Category list is only valid when the only labeled axis is the last one.");
- + List<String> labels = axisLabels.get(labeledAxis);
- + float[] data = tensorBuffer.getFloatArray();
- + SupportPreconditions.checkState(labels.size() == data.length);
- + List<Category> result = new ArrayList<>();
- + int i = 0;
- + for (String label : labels) {
- + result.add(new Category(label, data[i]));
- + i += 1;
- + }
- + return result;
- + }
- +
- + private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) {
- + int[] shape = tensorBuffer.getShape();
- + for (int i = 0; i < shape.length; i++) {
- + if (shape[i] > 1) {
- + return i;
- + }
- + }
- + throw new IllegalArgumentException(
- + "Cannot find an axis to label. A valid axis to label should have size larger than 1.");
- + }
- +
- + // Helper function to wrap the List<String> to a one-entry map.
- + private static Map<Integer, List<String>> makeMap(int axis, List<String> labels) {
- + Map<Integer, List<String>> map = new LinkedHashMap<>();
- + map.put(axis, labels);
- + return map;
- }
- - throw new IllegalArgumentException(
- - "Cannot find an axis to label. A valid axis to label should have size larger than 1.");
- - }
- -
- - // Helper function to wrap the List<String> to a one-entry map.
- - private static Map<Integer, List<String>> makeMap(int axis, List<String> labels) {
- - Map<Integer, List<String>> map = new LinkedHashMap<>();
- - map.put(axis, labels);
- - return map;
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java
- index ed47f65a726a6..e44edc64f4969 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java
- @@ -16,16 +16,18 @@ limitations under the License.
- package org.tensorflow.lite.support.label.ops;
-
- import android.content.Context;
- -import java.io.IOException;
- -import java.util.HashMap;
- -import java.util.List;
- -import java.util.Map;
- +
- import org.checkerframework.checker.nullness.qual.NonNull;
- import org.tensorflow.lite.support.common.FileUtil;
- import org.tensorflow.lite.support.common.internal.SupportPreconditions;
- import org.tensorflow.lite.support.label.TensorLabel;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- +import java.io.IOException;
- +import java.util.HashMap;
- +import java.util.List;
- +import java.util.Map;
- +
- /**
- * Labels TensorBuffer with axisLabels for outputs.
- *
- @@ -33,42 +35,42 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- * a pair of the label name and the corresponding TensorBuffer value.
- */
- public class LabelAxisOp {
- - // Axis and its corresponding label names.
- - private final Map<Integer, List<String>> axisLabels;
- -
- - protected LabelAxisOp(Builder builder) {
- - axisLabels = builder.axisLabels;
- - }
- -
- - public TensorLabel apply(@NonNull TensorBuffer buffer) {
- - SupportPreconditions.checkNotNull(buffer, "Tensor buffer cannot be null.");
- - return new TensorLabel(axisLabels, buffer);
- - }
- -
- - /** The inner builder class to build a LabelTensor Operator. */
- - public static class Builder {
- + // Axis and its corresponding label names.
- private final Map<Integer, List<String>> axisLabels;
-
- - protected Builder() {
- - axisLabels = new HashMap<>();
- + protected LabelAxisOp(Builder builder) {
- + axisLabels = builder.axisLabels;
- }
-
- - public Builder addAxisLabel(@NonNull Context context, int axis, @NonNull String filePath)
- - throws IOException {
- - SupportPreconditions.checkNotNull(context, "Context cannot be null.");
- - SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
- - List<String> labels = FileUtil.loadLabels(context, filePath);
- - axisLabels.put(axis, labels);
- - return this;
- + public TensorLabel apply(@NonNull TensorBuffer buffer) {
- + SupportPreconditions.checkNotNull(buffer, "Tensor buffer cannot be null.");
- + return new TensorLabel(axisLabels, buffer);
- }
-
- - public Builder addAxisLabel(int axis, @NonNull List<String> labels) {
- - axisLabels.put(axis, labels);
- - return this;
- - }
- + /** The inner builder class to build a LabelTensor Operator. */
- + public static class Builder {
- + private final Map<Integer, List<String>> axisLabels;
- +
- + protected Builder() {
- + axisLabels = new HashMap<>();
- + }
- +
- + public Builder addAxisLabel(@NonNull Context context, int axis, @NonNull String filePath)
- + throws IOException {
- + SupportPreconditions.checkNotNull(context, "Context cannot be null.");
- + SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
- + List<String> labels = FileUtil.loadLabels(context, filePath);
- + axisLabels.put(axis, labels);
- + return this;
- + }
- +
- + public Builder addAxisLabel(int axis, @NonNull List<String> labels) {
- + axisLabels.put(axis, labels);
- + return this;
- + }
-
- - public LabelAxisOp build() {
- - return new LabelAxisOp(this);
- + public LabelAxisOp build() {
- + return new LabelAxisOp(this);
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java
- index 9cfcf923dedee..ada9b33fb0eea 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java
- @@ -16,54 +16,55 @@ limitations under the License.
- package org.tensorflow.lite.support.model;
-
- import android.util.Log;
- -import java.io.Closeable;
- -import java.io.IOException;
- +
- import org.checkerframework.checker.nullness.qual.Nullable;
- import org.tensorflow.lite.Delegate;
-
- +import java.io.Closeable;
- +import java.io.IOException;
- +
- /**
- * Helper class to create and call necessary methods of {@code GpuDelegate} which is not a strict
- * dependency.
- */
- class GpuDelegateProxy implements Delegate, Closeable {
- + private static final String TAG = "GpuDelegateProxy";
-
- - private static final String TAG = "GpuDelegateProxy";
- -
- - private final Delegate proxiedDelegate;
- - private final Closeable proxiedCloseable;
- + private final Delegate proxiedDelegate;
- + private final Closeable proxiedCloseable;
-
- - @Nullable
- - public static GpuDelegateProxy maybeNewInstance() {
- - try {
- - Class<?> clazz = Class.forName("org.tensorflow.lite.gpu.GpuDelegate");
- - Object instance = clazz.getDeclaredConstructor().newInstance();
- - return new GpuDelegateProxy(instance);
- - } catch (ReflectiveOperationException e) {
- - Log.e(TAG, "Failed to create the GpuDelegate dynamically.", e);
- - return null;
- + @Nullable
- + public static GpuDelegateProxy maybeNewInstance() {
- + try {
- + Class<?> clazz = Class.forName("org.tensorflow.lite.gpu.GpuDelegate");
- + Object instance = clazz.getDeclaredConstructor().newInstance();
- + return new GpuDelegateProxy(instance);
- + } catch (ReflectiveOperationException e) {
- + Log.e(TAG, "Failed to create the GpuDelegate dynamically.", e);
- + return null;
- + }
- }
- - }
-
- - /** Calls {@code close()} method of the delegate. */
- - @Override
- - public void close() {
- - try {
- - proxiedCloseable.close();
- - } catch (IOException e) {
- - // Should not trigger, because GpuDelegate#close never throws. The catch is required because
- - // of Closeable#close.
- - Log.e(TAG, "Failed to close the GpuDelegate.", e);
- + /** Calls {@code close()} method of the delegate. */
- + @Override
- + public void close() {
- + try {
- + proxiedCloseable.close();
- + } catch (IOException e) {
- + // Should not trigger, because GpuDelegate#close never throws. The catch is required
- + // because of Closeable#close.
- + Log.e(TAG, "Failed to close the GpuDelegate.", e);
- + }
- }
- - }
-
- - /** Calls {@code getNativeHandle()} method of the delegate. */
- - @Override
- - public long getNativeHandle() {
- - return proxiedDelegate.getNativeHandle();
- - }
- + /** Calls {@code getNativeHandle()} method of the delegate. */
- + @Override
- + public long getNativeHandle() {
- + return proxiedDelegate.getNativeHandle();
- + }
-
- - private GpuDelegateProxy(Object instance) {
- - this.proxiedCloseable = (Closeable) instance;
- - this.proxiedDelegate = (Delegate) instance;
- - }
- + private GpuDelegateProxy(Object instance) {
- + this.proxiedCloseable = (Closeable) instance;
- + this.proxiedDelegate = (Delegate) instance;
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java
- index 09b63f1b12beb..282f2b9aa599c 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java
- @@ -16,9 +16,7 @@ limitations under the License.
- package org.tensorflow.lite.support.model;
-
- import android.content.Context;
- -import java.io.IOException;
- -import java.nio.MappedByteBuffer;
- -import java.util.Map;
- +
- import org.checkerframework.checker.nullness.qual.NonNull;
- import org.checkerframework.checker.nullness.qual.Nullable;
- import org.tensorflow.lite.InterpreterApi;
- @@ -27,6 +25,10 @@ import org.tensorflow.lite.Tensor;
- import org.tensorflow.lite.support.common.FileUtil;
- import org.tensorflow.lite.support.common.internal.SupportPreconditions;
-
- +import java.io.IOException;
- +import java.nio.MappedByteBuffer;
- +import java.util.Map;
- +
- /**
- * The wrapper class for a TFLite model and a TFLite interpreter.
- *
- @@ -34,263 +36,254 @@ import org.tensorflow.lite.support.common.internal.SupportPreconditions;
- * interpreter instance to run it.
- */
- public class Model {
- + /** The runtime device type used for executing classification. */
- + public enum Device { CPU, NNAPI, GPU }
-
- - /** The runtime device type used for executing classification. */
- - public enum Device {
- - CPU,
- - NNAPI,
- - GPU
- - }
- -
- - /**
- - * Options for running the model. Configurable parameters includes:
- - *
- - * <ul>
- - * <li>{@code device} {@link Builder#setDevice(Device)} specifies the hardware to run the model.
- - * The default value is {@link Device#CPU}.
- - * <li>{@code numThreads} {@link Builder#setNumThreads(int)} specifies the number of threads
- - * used by TFLite inference. It's only effective when device is set to {@link Device#CPU}
- - * and default value is 1.
- - * </ul>
- - */
- - public static class Options {
- - private final Device device;
- - private final int numThreads;
- - private final TfLiteRuntime tfLiteRuntime;
- -
- - /** Builder of {@link Options}. See its doc for details. */
- - public static class Builder {
- - private Device device = Device.CPU;
- - private int numThreads = 1;
- - private TfLiteRuntime tfLiteRuntime;
- -
- - public Builder setDevice(Device device) {
- - this.device = device;
- - return this;
- - }
- -
- - public Builder setNumThreads(int numThreads) {
- - this.numThreads = numThreads;
- - return this;
- - }
- -
- - public Builder setTfLiteRuntime(TfLiteRuntime tfLiteRuntime) {
- - this.tfLiteRuntime = tfLiteRuntime;
- - return this;
- - }
- -
- - public Options build() {
- - return new Options(this);
- - }
- + /**
- + * Options for running the model. Configurable parameters includes:
- + *
- + * <ul>
- + * <li>{@code device} {@link Builder#setDevice(Device)} specifies the hardware to run the
- + * model. The default value is {@link Device#CPU}. <li>{@code numThreads} {@link
- + * Builder#setNumThreads(int)} specifies the number of threads used by TFLite inference. It's
- + * only effective when device is set to {@link Device#CPU} and default value is 1.
- + * </ul>
- + */
- + public static class Options {
- + private final Device device;
- + private final int numThreads;
- + private final TfLiteRuntime tfLiteRuntime;
- +
- + /** Builder of {@link Options}. See its doc for details. */
- + public static class Builder {
- + private Device device = Device.CPU;
- + private int numThreads = 1;
- + private TfLiteRuntime tfLiteRuntime;
- +
- + public Builder setDevice(Device device) {
- + this.device = device;
- + return this;
- + }
- +
- + public Builder setNumThreads(int numThreads) {
- + this.numThreads = numThreads;
- + return this;
- + }
- +
- + public Builder setTfLiteRuntime(TfLiteRuntime tfLiteRuntime) {
- + this.tfLiteRuntime = tfLiteRuntime;
- + return this;
- + }
- +
- + public Options build() {
- + return new Options(this);
- + }
- + }
- +
- + private Options(Builder builder) {
- + device = builder.device;
- + numThreads = builder.numThreads;
- + tfLiteRuntime = builder.tfLiteRuntime;
- + }
- }
-
- - private Options(Builder builder) {
- - device = builder.device;
- - numThreads = builder.numThreads;
- - tfLiteRuntime = builder.tfLiteRuntime;
- - }
- - }
- + /** An instance of the driver class to run model inference with Tensorflow Lite. */
- + private final InterpreterApi interpreter;
-
- - /** An instance of the driver class to run model inference with Tensorflow Lite. */
- - private final InterpreterApi interpreter;
- + /** Path to tflite model file in asset folder. */
- + private final String modelPath;
-
- - /** Path to tflite model file in asset folder. */
- - private final String modelPath;
- + /** The memory-mapped model data. */
- + private final MappedByteBuffer byteModel;
-
- - /** The memory-mapped model data. */
- - private final MappedByteBuffer byteModel;
- + private final GpuDelegateProxy gpuDelegateProxy;
-
- - private final GpuDelegateProxy gpuDelegateProxy;
- + /**
- + * Builder for {@link Model}.
- + *
- + * @deprecated Please use {@link Model#createModel(Context, String, Options)}.
- + */
- + @Deprecated
- + public static class Builder {
- + private Device device = Device.CPU;
- + private int numThreads = 1;
- + private final String modelPath;
- + private final MappedByteBuffer byteModel;
- +
- + /**
- + * Creates a builder which loads tflite model from asset folder using memory-mapped files.
- + *
- + * @param context Application context to access assets.
- + * @param modelPath Asset path of the model (.tflite file).
- + * @throws IOException if an I/O error occurs when loading the tflite model.
- + */
- + public Builder(@NonNull Context context, @NonNull String modelPath) throws IOException {
- + this.modelPath = modelPath;
- + byteModel = FileUtil.loadMappedFile(context, modelPath);
- + }
- +
- + /** Sets running device. By default, TFLite will run on CPU. */
- + @NonNull
- + public Builder setDevice(Device device) {
- + this.device = device;
- + return this;
- + }
- +
- + /** Sets number of threads. By default it's 1. */
- + @NonNull
- + public Builder setNumThreads(int numThreads) {
- + this.numThreads = numThreads;
- + return this;
- + }
- +
- + // Note: The implementation is copied from `Model#createModel`. As the builder is going to
- + // be deprecated, this function is also to be removed.
- + @NonNull
- + public Model build() {
- + Options options =
- + new Options.Builder().setNumThreads(numThreads).setDevice(device).build();
- + return createModel(byteModel, modelPath, options);
- + }
- + }
-
- - /**
- - * Builder for {@link Model}.
- - *
- - * @deprecated Please use {@link Model#createModel(Context, String, Options)}.
- - */
- - @Deprecated
- - public static class Builder {
- - private Device device = Device.CPU;
- - private int numThreads = 1;
- - private final String modelPath;
- - private final MappedByteBuffer byteModel;
- + /**
- + * Loads a model from assets and initialize TFLite interpreter.
- + *
- + * <p>The default options are: (1) CPU device; (2) one thread.
- + *
- + * @param context The App Context.
- + * @param modelPath The path of the model file.
- + * @throws IOException if any exception occurs when open the model file.
- + */
- + public static Model createModel(@NonNull Context context, @NonNull String modelPath)
- + throws IOException {
- + return createModel(context, modelPath, new Options.Builder().build());
- + }
-
- /**
- - * Creates a builder which loads tflite model from asset folder using memory-mapped files.
- + * Loads a model from assets and initialize TFLite interpreter with given options.
- *
- - * @param context Application context to access assets.
- - * @param modelPath Asset path of the model (.tflite file).
- - * @throws IOException if an I/O error occurs when loading the tflite model.
- + * @see Options for details.
- + * @param context The App Context.
- + * @param modelPath The path of the model file.
- + * @param options The options for running the model.
- + * @throws IOException if any exception occurs when open the model file.
- */
- - public Builder(@NonNull Context context, @NonNull String modelPath) throws IOException {
- - this.modelPath = modelPath;
- - byteModel = FileUtil.loadMappedFile(context, modelPath);
- + public static Model createModel(@NonNull Context context, @NonNull String modelPath,
- + @NonNull Options options) throws IOException {
- + SupportPreconditions.checkNotEmpty(
- + modelPath, "Model path in the asset folder cannot be empty.");
- + MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, modelPath);
- + return createModel(byteModel, modelPath, options);
- }
-
- - /** Sets running device. By default, TFLite will run on CPU. */
- - @NonNull
- - public Builder setDevice(Device device) {
- - this.device = device;
- - return this;
- + /**
- + * Creates a model with loaded {@link MappedByteBuffer}.
- + *
- + * @see Options for details.
- + * @param byteModel The loaded TFLite model.
- + * @param modelPath The original path of the model. It can be fetched later by {@link
- + * Model#getPath()}.
- + * @param options The options for running the model.
- + * @throws IllegalArgumentException if {@code options.device} is {@link Device#GPU} but
- + * "tensorflow-lite-gpu" is not linked to the project.
- + */
- + public static Model createModel(@NonNull MappedByteBuffer byteModel, @NonNull String modelPath,
- + @NonNull Options options) {
- + InterpreterApi.Options interpreterOptions = new InterpreterApi.Options();
- + GpuDelegateProxy gpuDelegateProxy = null;
- + switch (options.device) {
- + case NNAPI:
- + interpreterOptions.setUseNNAPI(true);
- + break;
- + case GPU:
- + gpuDelegateProxy = GpuDelegateProxy.maybeNewInstance();
- + SupportPreconditions.checkArgument(gpuDelegateProxy != null,
- + "Cannot inference with GPU. Did you add \"tensorflow-lite-gpu\" as dependency?");
- + interpreterOptions.addDelegate(gpuDelegateProxy);
- + break;
- + case CPU:
- + break;
- + }
- + interpreterOptions.setNumThreads(options.numThreads);
- + if (options.tfLiteRuntime != null) {
- + interpreterOptions.setRuntime(options.tfLiteRuntime);
- + }
- + InterpreterApi interpreter = InterpreterApi.create(byteModel, interpreterOptions);
- + return new Model(modelPath, byteModel, interpreter, gpuDelegateProxy);
- }
-
- - /** Sets number of threads. By default it's 1. */
- + /** Returns the memory-mapped model data. */
- @NonNull
- - public Builder setNumThreads(int numThreads) {
- - this.numThreads = numThreads;
- - return this;
- + public MappedByteBuffer getData() {
- + return byteModel;
- }
-
- - // Note: The implementation is copied from `Model#createModel`. As the builder is going to be
- - // deprecated, this function is also to be removed.
- + /** Returns the path of the model file stored in Assets. */
- @NonNull
- - public Model build() {
- - Options options = new Options.Builder().setNumThreads(numThreads).setDevice(device).build();
- - return createModel(byteModel, modelPath, options);
- + public String getPath() {
- + return modelPath;
- }
- - }
- -
- - /**
- - * Loads a model from assets and initialize TFLite interpreter.
- - *
- - * <p>The default options are: (1) CPU device; (2) one thread.
- - *
- - * @param context The App Context.
- - * @param modelPath The path of the model file.
- - * @throws IOException if any exception occurs when open the model file.
- - */
- - public static Model createModel(@NonNull Context context, @NonNull String modelPath)
- - throws IOException {
- - return createModel(context, modelPath, new Options.Builder().build());
- - }
- -
- - /**
- - * Loads a model from assets and initialize TFLite interpreter with given options.
- - *
- - * @see Options for details.
- - * @param context The App Context.
- - * @param modelPath The path of the model file.
- - * @param options The options for running the model.
- - * @throws IOException if any exception occurs when open the model file.
- - */
- - public static Model createModel(
- - @NonNull Context context, @NonNull String modelPath, @NonNull Options options)
- - throws IOException {
- - SupportPreconditions.checkNotEmpty(
- - modelPath, "Model path in the asset folder cannot be empty.");
- - MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, modelPath);
- - return createModel(byteModel, modelPath, options);
- - }
- -
- - /**
- - * Creates a model with loaded {@link MappedByteBuffer}.
- - *
- - * @see Options for details.
- - * @param byteModel The loaded TFLite model.
- - * @param modelPath The original path of the model. It can be fetched later by {@link
- - * Model#getPath()}.
- - * @param options The options for running the model.
- - * @throws IllegalArgumentException if {@code options.device} is {@link Device#GPU} but
- - * "tensorflow-lite-gpu" is not linked to the project.
- - */
- - public static Model createModel(
- - @NonNull MappedByteBuffer byteModel, @NonNull String modelPath, @NonNull Options options) {
- - InterpreterApi.Options interpreterOptions = new InterpreterApi.Options();
- - GpuDelegateProxy gpuDelegateProxy = null;
- - switch (options.device) {
- - case NNAPI:
- - interpreterOptions.setUseNNAPI(true);
- - break;
- - case GPU:
- - gpuDelegateProxy = GpuDelegateProxy.maybeNewInstance();
- - SupportPreconditions.checkArgument(
- - gpuDelegateProxy != null,
- - "Cannot inference with GPU. Did you add \"tensorflow-lite-gpu\" as dependency?");
- - interpreterOptions.addDelegate(gpuDelegateProxy);
- - break;
- - case CPU:
- - break;
- +
- + /**
- + * Gets the Tensor associated with the provided input index.
- + *
- + * @throws IllegalStateException if the interpreter is closed.
- + */
- + public Tensor getInputTensor(int inputIndex) {
- + return interpreter.getInputTensor(inputIndex);
- }
- - interpreterOptions.setNumThreads(options.numThreads);
- - if (options.tfLiteRuntime != null) {
- - interpreterOptions.setRuntime(options.tfLiteRuntime);
- +
- + /**
- + * Gets the Tensor associated with the provided output index.
- + *
- + * @throws IllegalStateException if the interpreter is closed.
- + */
- + public Tensor getOutputTensor(int outputIndex) {
- + return interpreter.getOutputTensor(outputIndex);
- }
- - InterpreterApi interpreter = InterpreterApi.create(byteModel, interpreterOptions);
- - return new Model(modelPath, byteModel, interpreter, gpuDelegateProxy);
- - }
- -
- - /** Returns the memory-mapped model data. */
- - @NonNull
- - public MappedByteBuffer getData() {
- - return byteModel;
- - }
- -
- - /** Returns the path of the model file stored in Assets. */
- - @NonNull
- - public String getPath() {
- - return modelPath;
- - }
- -
- - /**
- - * Gets the Tensor associated with the provided input index.
- - *
- - * @throws IllegalStateException if the interpreter is closed.
- - */
- - public Tensor getInputTensor(int inputIndex) {
- - return interpreter.getInputTensor(inputIndex);
- - }
- -
- - /**
- - * Gets the Tensor associated with the provided output index.
- - *
- - * @throws IllegalStateException if the interpreter is closed.
- - */
- - public Tensor getOutputTensor(int outputIndex) {
- - return interpreter.getOutputTensor(outputIndex);
- - }
- -
- - /**
- - * Returns the output shape. Useful if output shape is only determined when graph is created.
- - *
- - * @throws IllegalStateException if the interpreter is closed.
- - */
- - public int[] getOutputTensorShape(int outputIndex) {
- - return interpreter.getOutputTensor(outputIndex).shape();
- - }
- -
- - /**
- - * Runs model inference on multiple inputs, and returns multiple outputs.
- - *
- - * @param inputs an array of input data. The inputs should be in the same order as inputs of the
- - * model. Each input can be an array or multidimensional array, or a {@link
- - * java.nio.ByteBuffer} of primitive types including int, float, long, and byte. {@link
- - * java.nio.ByteBuffer} is the preferred way to pass large input data, whereas string types
- - * require using the (multi-dimensional) array input path. When {@link java.nio.ByteBuffer} is
- - * used, its content should remain unchanged until model inference is done.
- - * @param outputs a map mapping output indices to multidimensional arrays of output data or {@link
- - * java.nio.ByteBuffer}s of primitive types including int, float, long, and byte. It only
- - * needs to keep entries for the outputs to be used.
- - */
- - public void run(@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
- - interpreter.runForMultipleInputsOutputs(inputs, outputs);
- - }
- -
- - public void close() {
- - if (interpreter != null) {
- - interpreter.close();
- +
- + /**
- + * Returns the output shape. Useful if output shape is only determined when graph is created.
- + *
- + * @throws IllegalStateException if the interpreter is closed.
- + */
- + public int[] getOutputTensorShape(int outputIndex) {
- + return interpreter.getOutputTensor(outputIndex).shape();
- + }
- +
- + /**
- + * Runs model inference on multiple inputs, and returns multiple outputs.
- + *
- + * @param inputs an array of input data. The inputs should be in the same order as inputs of the
- + * model. Each input can be an array or multidimensional array, or a {@link
- + * java.nio.ByteBuffer} of primitive types including int, float, long, and byte. {@link
- + * java.nio.ByteBuffer} is the preferred way to pass large input data, whereas string types
- + * require using the (multi-dimensional) array input path. When {@link java.nio.ByteBuffer}
- + * is used, its content should remain unchanged until model inference is done.
- + * @param outputs a map mapping output indices to multidimensional arrays of output data or
- + * {@link
- + * java.nio.ByteBuffer}s of primitive types including int, float, long, and byte. It only
- + * needs to keep entries for the outputs to be used.
- + */
- + public void run(@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
- + interpreter.runForMultipleInputsOutputs(inputs, outputs);
- }
- - if (gpuDelegateProxy != null) {
- - gpuDelegateProxy.close();
- +
- + public void close() {
- + if (interpreter != null) {
- + interpreter.close();
- + }
- + if (gpuDelegateProxy != null) {
- + gpuDelegateProxy.close();
- + }
- + }
- +
- + private Model(@NonNull String modelPath, @NonNull MappedByteBuffer byteModel,
- + @NonNull InterpreterApi interpreter, @Nullable GpuDelegateProxy gpuDelegateProxy) {
- + this.modelPath = modelPath;
- + this.byteModel = byteModel;
- + this.interpreter = interpreter;
- + this.gpuDelegateProxy = gpuDelegateProxy;
- }
- - }
- -
- - private Model(
- - @NonNull String modelPath,
- - @NonNull MappedByteBuffer byteModel,
- - @NonNull InterpreterApi interpreter,
- - @Nullable GpuDelegateProxy gpuDelegateProxy) {
- - this.modelPath = modelPath;
- - this.byteModel = byteModel;
- - this.interpreter = interpreter;
- - this.gpuDelegateProxy = gpuDelegateProxy;
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
- index 9e0204bdc2e71..ec6c800ef557a 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
- @@ -19,473 +19,476 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
- import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkNotNull;
- import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkState;
-
- +import org.checkerframework.checker.nullness.qual.NonNull;
- +import org.tensorflow.lite.DataType;
- +
- import java.nio.ByteBuffer;
- import java.nio.ByteOrder;
- import java.util.Arrays;
- -import org.checkerframework.checker.nullness.qual.NonNull;
- -import org.tensorflow.lite.DataType;
-
- /** Represents the data buffer for either a model's input or its output. */
- public abstract class TensorBuffer {
- - /** Where the data is stored. */
- - protected ByteBuffer buffer;
- -
- - /** Shape of the tensor stored in this buffer. */
- - protected int[] shape;
- -
- - /** Number of elements in the buffer. It will be changed to a proper value in the constructor. */
- - protected int flatSize = -1;
- -
- - /**
- - * Indicator of whether this buffer is dynamic or fixed-size. Fixed-size buffers will have
- - * pre-allocated memory and fixed size. While the size of dynamic buffers can be changed.
- - */
- - protected final boolean isDynamic;
- -
- - /**
- - * Creates a {@link TensorBuffer} with specified {@code shape} and {@link DataType}. Here are some
- - * examples:
- - *
- - * <pre>
- - * // Creating a float TensorBuffer with shape {2, 3}:
- - * int[] shape = new int[] {2, 3};
- - * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- - * </pre>
- - *
- - * <pre>
- - * // Creating an uint8 TensorBuffer of a scalar:
- - * int[] shape = new int[] {};
- - * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- - * </pre>
- - *
- - * <pre>
- - * // Creating an empty uint8 TensorBuffer:
- - * int[] shape = new int[] {0};
- - * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- - * </pre>
- - *
- - * <p>The size of a fixed-size TensorBuffer cannot be changed once it is created.
- - *
- - * @param shape The shape of the {@link TensorBuffer} to be created.
- - * @param dataType The dataType of the {@link TensorBuffer} to be created.
- - * @throws NullPointerException if {@code shape} is null.
- - * @throws IllegalArgumentException if {@code shape} has non-positive elements.
- - */
- - @NonNull
- - public static TensorBuffer createFixedSize(@NonNull int[] shape, DataType dataType) {
- - switch (dataType) {
- - case FLOAT32:
- - return new TensorBufferFloat(shape);
- - case UINT8:
- - return new TensorBufferUint8(shape);
- - default:
- - throw new AssertionError("TensorBuffer does not support data type: " + dataType);
- + /** Where the data is stored. */
- + protected ByteBuffer buffer;
- +
- + /** Shape of the tensor stored in this buffer. */
- + protected int[] shape;
- +
- + /**
- + * Number of elements in the buffer. It will be changed to a proper value in the constructor.
- + */
- + protected int flatSize = -1;
- +
- + /**
- + * Indicator of whether this buffer is dynamic or fixed-size. Fixed-size buffers will have
- + * pre-allocated memory and fixed size. While the size of dynamic buffers can be changed.
- + */
- + protected final boolean isDynamic;
- +
- + /**
- + * Creates a {@link TensorBuffer} with specified {@code shape} and {@link DataType}. Here are
- + * some examples:
- + *
- + * <pre>
- + * // Creating a float TensorBuffer with shape {2, 3}:
- + * int[] shape = new int[] {2, 3};
- + * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- + * </pre>
- + *
- + * <pre>
- + * // Creating an uint8 TensorBuffer of a scalar:
- + * int[] shape = new int[] {};
- + * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- + * </pre>
- + *
- + * <pre>
- + * // Creating an empty uint8 TensorBuffer:
- + * int[] shape = new int[] {0};
- + * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- + * </pre>
- + *
- + * <p>The size of a fixed-size TensorBuffer cannot be changed once it is created.
- + *
- + * @param shape The shape of the {@link TensorBuffer} to be created.
- + * @param dataType The dataType of the {@link TensorBuffer} to be created.
- + * @throws NullPointerException if {@code shape} is null.
- + * @throws IllegalArgumentException if {@code shape} has non-positive elements.
- + */
- + @NonNull
- + public static TensorBuffer createFixedSize(@NonNull int[] shape, DataType dataType) {
- + switch (dataType) {
- + case FLOAT32:
- + return new TensorBufferFloat(shape);
- + case UINT8:
- + return new TensorBufferUint8(shape);
- + default:
- + throw new AssertionError("TensorBuffer does not support data type: " + dataType);
- + }
- + }
- +
- + /**
- + * Creates an empty dynamic {@link TensorBuffer} with specified {@link DataType}. The shape of
- + * the created {@link TensorBuffer} is {0}.
- + *
- + * <p>Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of
- + * different buffer sizes. Here are some examples:
- + *
- + * <pre>
- + * // Creating a float dynamic TensorBuffer:
- + * TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- + * // Loading a float array:
- + * float[] arr1 = new float[] {1, 2, 3};
- + * tensorBuffer.loadArray(arr, new int[] {arr1.length});
- + * // loading another float array:
- + * float[] arr2 = new float[] {1, 2, 3, 4, 5};
- + * tensorBuffer.loadArray(arr, new int[] {arr2.length});
- + * // loading a third float array with the same size as arr2, assuming shape doesn't change:
- + * float[] arr3 = new float[] {5, 4, 3, 2, 1};
- + * tensorBuffer.loadArray(arr);
- + * // loading a forth float array with different size as arr3 and omitting the shape will result
- + * // in error:
- + * float[] arr4 = new float[] {3, 2, 1};
- + * tensorBuffer.loadArray(arr); // Error: The size of byte buffer and the shape do not match.
- + * </pre>
- + *
- + * @param dataType The dataType of the {@link TensorBuffer} to be created.
- + */
- + @NonNull
- + public static TensorBuffer createDynamic(DataType dataType) {
- + switch (dataType) {
- + case FLOAT32:
- + return new TensorBufferFloat();
- + case UINT8:
- + return new TensorBufferUint8();
- + default:
- + throw new AssertionError("TensorBuffer does not support data type: " + dataType);
- + }
- }
- - }
- -
- - /**
- - * Creates an empty dynamic {@link TensorBuffer} with specified {@link DataType}. The shape of the
- - * created {@link TensorBuffer} is {0}.
- - *
- - * <p>Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of
- - * different buffer sizes. Here are some examples:
- - *
- - * <pre>
- - * // Creating a float dynamic TensorBuffer:
- - * TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- - * // Loading a float array:
- - * float[] arr1 = new float[] {1, 2, 3};
- - * tensorBuffer.loadArray(arr, new int[] {arr1.length});
- - * // loading another float array:
- - * float[] arr2 = new float[] {1, 2, 3, 4, 5};
- - * tensorBuffer.loadArray(arr, new int[] {arr2.length});
- - * // loading a third float array with the same size as arr2, assuming shape doesn't change:
- - * float[] arr3 = new float[] {5, 4, 3, 2, 1};
- - * tensorBuffer.loadArray(arr);
- - * // loading a forth float array with different size as arr3 and omitting the shape will result
- - * // in error:
- - * float[] arr4 = new float[] {3, 2, 1};
- - * tensorBuffer.loadArray(arr); // Error: The size of byte buffer and the shape do not match.
- - * </pre>
- - *
- - * @param dataType The dataType of the {@link TensorBuffer} to be created.
- - */
- - @NonNull
- - public static TensorBuffer createDynamic(DataType dataType) {
- - switch (dataType) {
- - case FLOAT32:
- - return new TensorBufferFloat();
- - case UINT8:
- - return new TensorBufferUint8();
- - default:
- - throw new AssertionError("TensorBuffer does not support data type: " + dataType);
- +
- + /**
- + * Creates a {@link TensorBuffer} deep-copying data from another, with specified {@link
- + * DataType}.
- + *
- + * @param buffer the source {@link TensorBuffer} to copy from.
- + * @param dataType the expected {@link DataType} of newly created {@link TensorBuffer}.
- + * @throws NullPointerException if {@code buffer} is null.
- + */
- + @NonNull
- + public static TensorBuffer createFrom(@NonNull TensorBuffer buffer, DataType dataType) {
- + checkNotNull(buffer, "Cannot create a buffer from null");
- + TensorBuffer result;
- + if (buffer.isDynamic()) {
- + result = createDynamic(dataType);
- + } else {
- + result = createFixedSize(buffer.shape, dataType);
- + }
- + // The only scenario we need float array is FLOAT32->FLOAT32, or we can always use INT as
- + // intermediate container.
- + // The assumption is not true when we support other data types.
- + if (buffer.getDataType() == DataType.FLOAT32 && dataType == DataType.FLOAT32) {
- + float[] data = buffer.getFloatArray();
- + result.loadArray(data, buffer.shape);
- + } else {
- + int[] data = buffer.getIntArray();
- + result.loadArray(data, buffer.shape);
- + }
- + return result;
- }
- - }
- -
- - /**
- - * Creates a {@link TensorBuffer} deep-copying data from another, with specified {@link DataType}.
- - *
- - * @param buffer the source {@link TensorBuffer} to copy from.
- - * @param dataType the expected {@link DataType} of newly created {@link TensorBuffer}.
- - * @throws NullPointerException if {@code buffer} is null.
- - */
- - @NonNull
- - public static TensorBuffer createFrom(@NonNull TensorBuffer buffer, DataType dataType) {
- - checkNotNull(buffer, "Cannot create a buffer from null");
- - TensorBuffer result;
- - if (buffer.isDynamic()) {
- - result = createDynamic(dataType);
- - } else {
- - result = createFixedSize(buffer.shape, dataType);
- +
- + /** Returns the data buffer. */
- + @NonNull
- + public ByteBuffer getBuffer() {
- + return buffer;
- }
- - // The only scenario we need float array is FLOAT32->FLOAT32, or we can always use INT as
- - // intermediate container.
- - // The assumption is not true when we support other data types.
- - if (buffer.getDataType() == DataType.FLOAT32 && dataType == DataType.FLOAT32) {
- - float[] data = buffer.getFloatArray();
- - result.loadArray(data, buffer.shape);
- - } else {
- - int[] data = buffer.getIntArray();
- - result.loadArray(data, buffer.shape);
- +
- + /**
- + * Gets the flatSize of the buffer.
- + *
- + * @throws IllegalStateException if the underlying data is corrupted
- + */
- + public int getFlatSize() {
- + assertShapeIsCorrect();
- + return flatSize;
- }
- - return result;
- - }
- -
- - /** Returns the data buffer. */
- - @NonNull
- - public ByteBuffer getBuffer() {
- - return buffer;
- - }
- -
- - /**
- - * Gets the flatSize of the buffer.
- - *
- - * @throws IllegalStateException if the underlying data is corrupted
- - */
- - public int getFlatSize() {
- - assertShapeIsCorrect();
- - return flatSize;
- - }
- -
- - /**
- - * Gets the current shape. (returning a copy here to avoid unexpected modification.)
- - *
- - * @throws IllegalStateException if the underlying data is corrupted
- - */
- - @NonNull
- - public int[] getShape() {
- - assertShapeIsCorrect();
- - return Arrays.copyOf(shape, shape.length);
- - }
- -
- - /** Returns the data type of this buffer. */
- - public abstract DataType getDataType();
- -
- - /**
- - * Returns a float array of the values stored in this buffer. If the buffer is of different types
- - * than float, the values will be converted into float. For example, values in {@link
- - * TensorBufferUint8} will be converted from uint8 to float.
- - */
- - @NonNull
- - public abstract float[] getFloatArray();
- -
- - /**
- - * Returns a float value at a given index. If the buffer is of different types than float, the
- - * value will be converted into float. For example, when reading a value from {@link
- - * TensorBufferUint8}, the value will be first read out as uint8, and then will be converted from
- - * uint8 to float.
- - *
- - * <pre>
- - * For example, a TensorBuffer with shape {2, 3} that represents the following array,
- - * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
- - *
- - * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by:
- - * float v = tensorBuffer.getFloatValue(3);
- - * </pre>
- - *
- - * @param absIndex The absolute index of the value to be read.
- - */
- - public abstract float getFloatValue(int absIndex);
- -
- - /**
- - * Returns an int array of the values stored in this buffer. If the buffer is of different type
- - * than int, the values will be converted into int, and loss of precision may apply. For example,
- - * getting an int array from a {@link TensorBufferFloat} with values {400.32f, 23.04f}, the output
- - * is {400, 23}.
- - */
- - @NonNull
- - public abstract int[] getIntArray();
- -
- - /**
- - * Returns an int value at a given index. If the buffer is of different types than int, the value
- - * will be converted into int. For example, when reading a value from {@link TensorBufferFloat},
- - * the value will be first read out as float, and then will be converted from float to int. Loss
- - * of precision may apply.
- - *
- - * <pre>
- - * For example, a TensorBuffer with shape {2, 3} that represents the following array,
- - * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
- - *
- - * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by:
- - * int v = tensorBuffer.getIntValue(3);
- - * Note that v is converted from 3.0f to 3 as a result of type conversion.
- - * </pre>
- - *
- - * @param absIndex The absolute index of the value to be read.
- - */
- - public abstract int getIntValue(int absIndex);
- -
- - /**
- - * Returns the number of bytes of a single element in the array. For example, a float buffer will
- - * return 4, and a byte buffer will return 1.
- - */
- - public abstract int getTypeSize();
- -
- - /** Returns if the {@link TensorBuffer} is dynamic sized (could resize arbitrarily). */
- - public boolean isDynamic() {
- - return isDynamic;
- - }
- -
- - /**
- - * Loads an int array into this buffer with specific shape. If the buffer is of different types
- - * than int, the values will be converted into the buffer's type before being loaded into the
- - * buffer, and loss of precision may apply. For example, loading an int array with values {400,
- - * -23} into a {@link TensorBufferUint8} , the values will be clamped to [0, 255] and then be
- - * casted to uint8 by {255, 0}.
- - *
- - * @param src The source array to be loaded.
- - * @param shape Shape of the tensor that {@code src} represents.
- - * @throws NullPointerException if {@code src} is null.
- - * @throws NullPointerException if {@code shape} is null.
- - * @throws IllegalArgumentException if the size of the array to be loaded does not match the
- - * specified shape.
- - */
- - public abstract void loadArray(@NonNull int[] src, @NonNull int[] shape);
- -
- - /**
- - * Loads an int array into this buffer. If the buffer is of different types than int, the values
- - * will be converted into the buffer's type before being loaded into the buffer, and loss of
- - * precision may apply. For example, loading an int array with values {400, -23} into a {@link
- - * TensorBufferUint8} , the values will be clamped to [0, 255] and then be casted to uint8 by
- - * {255, 0}.
- - *
- - * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this
- - * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always match
- - * the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
- - * TensorBuffer}. Use {@link #loadArray(int[], int[])} if {@code src} has a different shape.
- - *
- - * @param src The source array to be loaded.
- - */
- - public void loadArray(@NonNull int[] src) {
- - loadArray(src, shape);
- - }
- -
- - /**
- - * Loads a float array into this buffer with specific shape. If the buffer is of different types
- - * than float, the values will be converted into the buffer's type before being loaded into the
- - * buffer, and loss of precision may apply. For example, loading a float array into a {@link
- - * TensorBufferUint8} with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and
- - * then be casted to uint8 by {255, 0}.
- - *
- - * @param src The source array to be loaded.
- - * @param shape Shape of the tensor that {@code src} represents.
- - * @throws NullPointerException if {@code src} is null.
- - * @throws NullPointerException if {@code shape} is null.
- - * @throws IllegalArgumentException if the size of the array to be loaded does not match the
- - * specified shape.
- - */
- - public abstract void loadArray(@NonNull float[] src, @NonNull int[] shape);
- -
- - /**
- - * Loads a float array into this buffer. If the buffer is of different types than float, the
- - * values will be converted into the buffer's type before being loaded into the buffer, and loss
- - * of precision may apply. For example, loading a float array into a {@link TensorBufferUint8}
- - * with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and then be casted to
- - * uint8 by {255, 0}.
- - *
- - * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this
- - * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always match
- - * the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
- - * TensorBuffer}. Use {@link #loadArray(float[], int[])} if {@code src} has a different shape.
- - *
- - * @param src The source array to be loaded.
- - */
- - public void loadArray(@NonNull float[] src) {
- - loadArray(src, shape);
- - }
- -
- - /**
- - * Loads a byte buffer into this {@link TensorBuffer} with specific shape.
- - *
- - * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for
- - * performance concern, but if modification is necessary, please make a copy.
- - *
- - * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer}
- - * backed by an array.
- - *
- - * @param buffer The byte buffer to load.
- - * @throws NullPointerException if {@code buffer} is null.
- - * @throws IllegalArgumentException if the size of {@code buffer} and {@code typeSize} do not
- - * match or the size of {@code buffer} and {@code flatSize} do not match.
- - */
- - public void loadBuffer(@NonNull ByteBuffer buffer, @NonNull int[] shape) {
- - checkNotNull(buffer, "Byte buffer cannot be null.");
- - checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative.");
- -
- - int flatSize = computeFlatSize(shape);
- - checkArgument(
- - (buffer.limit() == getTypeSize() * flatSize),
- - "The size of byte buffer and the shape do not match. Expected: "
- - + getTypeSize() * flatSize
- - + " Actual: "
- - + buffer.limit());
- -
- - if (!isDynamic) {
- - // Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
- - checkArgument(Arrays.equals(shape, this.shape));
- +
- + /**
- + * Gets the current shape. (returning a copy here to avoid unexpected modification.)
- + *
- + * @throws IllegalStateException if the underlying data is corrupted
- + */
- + @NonNull
- + public int[] getShape() {
- + assertShapeIsCorrect();
- + return Arrays.copyOf(shape, shape.length);
- }
-
- - // Update to the new shape, since shape dim values might change.
- - this.shape = shape.clone();
- - this.flatSize = flatSize;
- -
- - buffer.rewind();
- - this.buffer = buffer;
- - }
- -
- - /**
- - * Loads a byte buffer into this {@link TensorBuffer}. Buffer size must match the flat size of
- - * this {@link TensorBuffer}.
- - *
- - * <p>Using this method assumes that the shape of {@code buffer} is the same as the shape of this
- - * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code buffer.limit()}) should always
- - * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
- - * TensorBuffer}. Use {@link #loadBuffer(ByteBuffer, int[])} if {@code buffer} has a different
- - * shape.
- - *
- - * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for
- - * performance concern, but if modification is necessary, please make a copy.
- - *
- - * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer}
- - * backed by an array.
- - *
- - * <p>If the {@code buffer} is read-only, we adopt a copy-on-write strategy for performance.
- - *
- - * @param buffer The byte buffer to load.
- - */
- - public void loadBuffer(@NonNull ByteBuffer buffer) {
- - loadBuffer(buffer, shape);
- - }
- -
- - /**
- - * Constructs a fixed size {@link TensorBuffer} with specified {@code shape}.
- - *
- - * @throws NullPointerException if {@code shape} is null.
- - * @throws IllegalArgumentException if {@code shape} has non-positive elements.
- - */
- - protected TensorBuffer(@NonNull int[] shape) {
- - isDynamic = false;
- - allocateMemory(shape);
- - }
- -
- - /** Constructs a dynamic {@link TensorBuffer} which can be resized. */
- - protected TensorBuffer() {
- - isDynamic = true;
- - // Initialize the dynamic TensorBuffer with an empty ByteBuffer.
- - allocateMemory(new int[] {0});
- - }
- -
- - /** Calculates number of elements in the buffer. */
- - protected static int computeFlatSize(@NonNull int[] shape) {
- - checkNotNull(shape, "Shape cannot be null.");
- - int prod = 1;
- - for (int s : shape) {
- - prod = prod * s;
- + /** Returns the data type of this buffer. */
- + public abstract DataType getDataType();
- +
- + /**
- + * Returns a float array of the values stored in this buffer. If the buffer is of different
- + * types than float, the values will be converted into float. For example, values in {@link
- + * TensorBufferUint8} will be converted from uint8 to float.
- + */
- + @NonNull
- + public abstract float[] getFloatArray();
- +
- + /**
- + * Returns a float value at a given index. If the buffer is of different types than float, the
- + * value will be converted into float. For example, when reading a value from {@link
- + * TensorBufferUint8}, the value will be first read out as uint8, and then will be converted
- + * from uint8 to float.
- + *
- + * <pre>
- + * For example, a TensorBuffer with shape {2, 3} that represents the following array,
- + * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
- + *
- + * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by:
- + * float v = tensorBuffer.getFloatValue(3);
- + * </pre>
- + *
- + * @param absIndex The absolute index of the value to be read.
- + */
- + public abstract float getFloatValue(int absIndex);
- +
- + /**
- + * Returns an int array of the values stored in this buffer. If the buffer is of different type
- + * than int, the values will be converted into int, and loss of precision may apply. For
- + * example, getting an int array from a {@link TensorBufferFloat} with values {400.32f, 23.04f},
- + * the output is {400, 23}.
- + */
- + @NonNull
- + public abstract int[] getIntArray();
- +
- + /**
- + * Returns an int value at a given index. If the buffer is of different types than int, the
- + * value will be converted into int. For example, when reading a value from {@link
- + * TensorBufferFloat}, the value will be first read out as float, and then will be converted
- + * from float to int. Loss of precision may apply.
- + *
- + * <pre>
- + * For example, a TensorBuffer with shape {2, 3} that represents the following array,
- + * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
- + *
- + * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by:
- + * int v = tensorBuffer.getIntValue(3);
- + * Note that v is converted from 3.0f to 3 as a result of type conversion.
- + * </pre>
- + *
- + * @param absIndex The absolute index of the value to be read.
- + */
- + public abstract int getIntValue(int absIndex);
- +
- + /**
- + * Returns the number of bytes of a single element in the array. For example, a float buffer
- + * will return 4, and a byte buffer will return 1.
- + */
- + public abstract int getTypeSize();
- +
- + /** Returns if the {@link TensorBuffer} is dynamic sized (could resize arbitrarily). */
- + public boolean isDynamic() {
- + return isDynamic;
- }
- - return prod;
- - }
- -
- - /**
- - * For dynamic buffer, resize the memory if needed. For fixed-size buffer, check if the {@code
- - * shape} of src fits the buffer size.
- - */
- - protected void resize(@NonNull int[] shape) {
- - if (isDynamic) {
- - allocateMemory(shape);
- - } else {
- - // Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
- - checkArgument(Arrays.equals(shape, this.shape));
- - this.shape = shape.clone();
- +
- + /**
- + * Loads an int array into this buffer with specific shape. If the buffer is of different types
- + * than int, the values will be converted into the buffer's type before being loaded into the
- + * buffer, and loss of precision may apply. For example, loading an int array with values {400,
- + * -23} into a {@link TensorBufferUint8} , the values will be clamped to [0, 255] and then be
- + * casted to uint8 by {255, 0}.
- + *
- + * @param src The source array to be loaded.
- + * @param shape Shape of the tensor that {@code src} represents.
- + * @throws NullPointerException if {@code src} is null.
- + * @throws NullPointerException if {@code shape} is null.
- + * @throws IllegalArgumentException if the size of the array to be loaded does not match the
- + * specified shape.
- + */
- + public abstract void loadArray(@NonNull int[] src, @NonNull int[] shape);
- +
- + /**
- + * Loads an int array into this buffer. If the buffer is of different types than int, the values
- + * will be converted into the buffer's type before being loaded into the buffer, and loss of
- + * precision may apply. For example, loading an int array with values {400, -23} into a {@link
- + * TensorBufferUint8} , the values will be clamped to [0, 255] and then be casted to uint8 by
- + * {255, 0}.
- + *
- + * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this
- + * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always
- + * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
- + * TensorBuffer}. Use {@link #loadArray(int[], int[])} if {@code src} has a different shape.
- + *
- + * @param src The source array to be loaded.
- + */
- + public void loadArray(@NonNull int[] src) {
- + loadArray(src, shape);
- + }
- +
- + /**
- + * Loads a float array into this buffer with specific shape. If the buffer is of different types
- + * than float, the values will be converted into the buffer's type before being loaded into the
- + * buffer, and loss of precision may apply. For example, loading a float array into a {@link
- + * TensorBufferUint8} with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and
- + * then be casted to uint8 by {255, 0}.
- + *
- + * @param src The source array to be loaded.
- + * @param shape Shape of the tensor that {@code src} represents.
- + * @throws NullPointerException if {@code src} is null.
- + * @throws NullPointerException if {@code shape} is null.
- + * @throws IllegalArgumentException if the size of the array to be loaded does not match the
- + * specified shape.
- + */
- + public abstract void loadArray(@NonNull float[] src, @NonNull int[] shape);
- +
- + /**
- + * Loads a float array into this buffer. If the buffer is of different types than float, the
- + * values will be converted into the buffer's type before being loaded into the buffer, and loss
- + * of precision may apply. For example, loading a float array into a {@link TensorBufferUint8}
- + * with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and then be casted to
- + * uint8 by {255, 0}.
- + *
- + * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this
- + * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always
- + * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
- + * TensorBuffer}. Use {@link #loadArray(float[], int[])} if {@code src} has a different shape.
- + *
- + * @param src The source array to be loaded.
- + */
- + public void loadArray(@NonNull float[] src) {
- + loadArray(src, shape);
- }
- - }
-
- - /** Copies the underlying {@link ByteBuffer} if it's readonly. */
- - protected synchronized void copyByteBufferIfReadOnly() {
- - if (!buffer.isReadOnly()) {
- - return;
- + /**
- + * Loads a byte buffer into this {@link TensorBuffer} with specific shape.
- + *
- + * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here
- + * for performance concern, but if modification is necessary, please make a copy.
- + *
- + * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer}
- + * backed by an array.
- + *
- + * @param buffer The byte buffer to load.
- + * @throws NullPointerException if {@code buffer} is null.
- + * @throws IllegalArgumentException if the size of {@code buffer} and {@code typeSize} do not
- + * match or the size of {@code buffer} and {@code flatSize} do not match.
- + */
- + public void loadBuffer(@NonNull ByteBuffer buffer, @NonNull int[] shape) {
- + checkNotNull(buffer, "Byte buffer cannot be null.");
- + checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative.");
- +
- + int flatSize = computeFlatSize(shape);
- + checkArgument((buffer.limit() == getTypeSize() * flatSize),
- + "The size of byte buffer and the shape do not match. Expected: "
- + + getTypeSize() * flatSize + " Actual: " + buffer.limit());
- +
- + if (!isDynamic) {
- + // Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
- + checkArgument(Arrays.equals(shape, this.shape));
- + }
- +
- + // Update to the new shape, since shape dim values might change.
- + this.shape = shape.clone();
- + this.flatSize = flatSize;
- +
- + buffer.rewind();
- + this.buffer = buffer;
- }
- - ByteBuffer newByteBuffer = ByteBuffer.allocateDirect(buffer.capacity());
- - newByteBuffer.order(buffer.order());
- - newByteBuffer.put(buffer);
- - newByteBuffer.rewind();
- - buffer = newByteBuffer;
- - }
- -
- - /**
- - * Allocates buffer with corresponding size of the {@code shape}. If shape is an empty array, this
- - * {@link TensorBuffer} will be created as a scalar and its flatSize will be 1.
- - *
- - * @throws NullPointerException if {@code shape} is null.
- - * @throws IllegalArgumentException if {@code shape} has negative elements.
- - */
- - private void allocateMemory(@NonNull int[] shape) {
- - checkNotNull(shape, "TensorBuffer shape cannot be null.");
- - checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative.");
- -
- - // Check if the new shape is the same as current shape.
- - int newFlatSize = computeFlatSize(shape);
- - this.shape = shape.clone();
- - if (flatSize == newFlatSize) {
- - return;
- +
- + /**
- + * Loads a byte buffer into this {@link TensorBuffer}. Buffer size must match the flat size of
- + * this {@link TensorBuffer}.
- + *
- + * <p>Using this method assumes that the shape of {@code buffer} is the same as the shape of
- + * this
- + * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code buffer.limit()}) should always
- + * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
- + * TensorBuffer}. Use {@link #loadBuffer(ByteBuffer, int[])} if {@code buffer} has a different
- + * shape.
- + *
- + * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here
- + * for performance concern, but if modification is necessary, please make a copy.
- + *
- + * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer}
- + * backed by an array.
- + *
- + * <p>If the {@code buffer} is read-only, we adopt a copy-on-write strategy for performance.
- + *
- + * @param buffer The byte buffer to load.
- + */
- + public void loadBuffer(@NonNull ByteBuffer buffer) {
- + loadBuffer(buffer, shape);
- + }
- +
- + /**
- + * Constructs a fixed size {@link TensorBuffer} with specified {@code shape}.
- + *
- + * @throws NullPointerException if {@code shape} is null.
- + * @throws IllegalArgumentException if {@code shape} has non-positive elements.
- + */
- + protected TensorBuffer(@NonNull int[] shape) {
- + isDynamic = false;
- + allocateMemory(shape);
- + }
- +
- + /** Constructs a dynamic {@link TensorBuffer} which can be resized. */
- + protected TensorBuffer() {
- + isDynamic = true;
- + // Initialize the dynamic TensorBuffer with an empty ByteBuffer.
- + allocateMemory(new int[] {0});
- + }
- +
- + /** Calculates number of elements in the buffer. */
- + protected static int computeFlatSize(@NonNull int[] shape) {
- + checkNotNull(shape, "Shape cannot be null.");
- + int prod = 1;
- + for (int s : shape) {
- + prod = prod * s;
- + }
- + return prod;
- + }
- +
- + /**
- + * For dynamic buffer, resize the memory if needed. For fixed-size buffer, check if the {@code
- + * shape} of src fits the buffer size.
- + */
- + protected void resize(@NonNull int[] shape) {
- + if (isDynamic) {
- + allocateMemory(shape);
- + } else {
- + // Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
- + checkArgument(Arrays.equals(shape, this.shape));
- + this.shape = shape.clone();
- + }
- + }
- +
- + /** Copies the underlying {@link ByteBuffer} if it's readonly. */
- + protected synchronized void copyByteBufferIfReadOnly() {
- + if (!buffer.isReadOnly()) {
- + return;
- + }
- + ByteBuffer newByteBuffer = ByteBuffer.allocateDirect(buffer.capacity());
- + newByteBuffer.order(buffer.order());
- + newByteBuffer.put(buffer);
- + newByteBuffer.rewind();
- + buffer = newByteBuffer;
- + }
- +
- + /**
- + * Allocates buffer with corresponding size of the {@code shape}. If shape is an empty array,
- + * this
- + * {@link TensorBuffer} will be created as a scalar and its flatSize will be 1.
- + *
- + * @throws NullPointerException if {@code shape} is null.
- + * @throws IllegalArgumentException if {@code shape} has negative elements.
- + */
- + private void allocateMemory(@NonNull int[] shape) {
- + checkNotNull(shape, "TensorBuffer shape cannot be null.");
- + checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative.");
- +
- + // Check if the new shape is the same as current shape.
- + int newFlatSize = computeFlatSize(shape);
- + this.shape = shape.clone();
- + if (flatSize == newFlatSize) {
- + return;
- + }
- +
- + // Update to the new shape.
- + flatSize = newFlatSize;
- + buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize());
- + buffer.order(ByteOrder.nativeOrder());
- }
-
- - // Update to the new shape.
- - flatSize = newFlatSize;
- - buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize());
- - buffer.order(ByteOrder.nativeOrder());
- - }
- -
- - /**
- - * Verifies if the shape of the {@link TensorBuffer} matched the size of the underlying {@link
- - * ByteBuffer}.
- - */
- - private void assertShapeIsCorrect() {
- - int flatSize = computeFlatSize(shape);
- - checkState(
- - (buffer.limit() == getTypeSize() * flatSize),
- - String.format(
- - "The size of underlying ByteBuffer (%d) and the shape (%s) do not match. The"
- - + " ByteBuffer may have been changed.",
- - buffer.limit(), Arrays.toString(shape)));
- - }
- -
- - /**
- - * Checks if {@code shape} meets one of following two requirements: 1. Elements in {@code shape}
- - * are all non-negative numbers. 2. {@code shape} is an empty array, which corresponds to scalar.
- - */
- - private static boolean isShapeValid(@NonNull int[] shape) {
- - if (shape.length == 0) {
- - // This shape refers to a scalar.
- - return true;
- + /**
- + * Verifies if the shape of the {@link TensorBuffer} matched the size of the underlying {@link
- + * ByteBuffer}.
- + */
- + private void assertShapeIsCorrect() {
- + int flatSize = computeFlatSize(shape);
- + checkState((buffer.limit() == getTypeSize() * flatSize),
- + String.format(
- + "The size of underlying ByteBuffer (%d) and the shape (%s) do not match. The"
- + + " ByteBuffer may have been changed.",
- + buffer.limit(), Arrays.toString(shape)));
- }
-
- - // This shape refers to a multidimensional array.
- - for (int s : shape) {
- - // All elements in shape should be non-negative.
- - if (s < 0) {
- - return false;
- - }
- + /**
- + * Checks if {@code shape} meets one of following two requirements: 1. Elements in {@code shape}
- + * are all non-negative numbers. 2. {@code shape} is an empty array, which corresponds to
- + * scalar.
- + */
- + private static boolean isShapeValid(@NonNull int[] shape) {
- + if (shape.length == 0) {
- + // This shape refers to a scalar.
- + return true;
- + }
- +
- + // This shape refers to a multidimensional array.
- + for (int s : shape) {
- + // All elements in shape should be non-negative.
- + if (s < 0) {
- + return false;
- + }
- + }
- + return true;
- }
- - return true;
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java
- index 8d2bc5ad0c84d..632db6c886b17 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java
- @@ -15,103 +15,102 @@ limitations under the License.
-
- package org.tensorflow.lite.support.tensorbuffer;
-
- -import java.nio.FloatBuffer;
- import org.checkerframework.checker.nullness.qual.NonNull;
- import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.common.internal.SupportPreconditions;
-
- +import java.nio.FloatBuffer;
- +
- /** Represents data buffer with float values. */
- public final class TensorBufferFloat extends TensorBuffer {
- - private static final DataType DATA_TYPE = DataType.FLOAT32;
- -
- - /**
- - * Creates a {@link TensorBufferFloat} with specified {@code shape}.
- - *
- - * @throws NullPointerException if {@code shape} is null.
- - * @throws IllegalArgumentException if {@code shape} has non-positive elements.
- - */
- - TensorBufferFloat(@NonNull int[] shape) {
- - super(shape);
- - }
- -
- - TensorBufferFloat() {
- - super();
- - }
- -
- - @Override
- - public DataType getDataType() {
- - return DATA_TYPE;
- - }
- -
- - @Override
- - @NonNull
- - public float[] getFloatArray() {
- - buffer.rewind();
- - float[] arr = new float[flatSize];
- -
- - FloatBuffer floatBuffer = buffer.asFloatBuffer();
- - floatBuffer.get(arr);
- - return arr;
- - }
- -
- - @Override
- - public float getFloatValue(int absIndex) {
- - return buffer.getFloat(absIndex << 2);
- - }
- -
- - @Override
- - @NonNull
- - public int[] getIntArray() {
- - buffer.rewind();
- - float[] floatArr = new float[flatSize];
- - buffer.asFloatBuffer().get(floatArr);
- -
- - int[] intArr = new int[flatSize];
- - for (int i = 0; i < flatSize; i++) {
- - intArr[i] = (int) floatArr[i];
- + private static final DataType DATA_TYPE = DataType.FLOAT32;
- +
- + /**
- + * Creates a {@link TensorBufferFloat} with specified {@code shape}.
- + *
- + * @throws NullPointerException if {@code shape} is null.
- + * @throws IllegalArgumentException if {@code shape} has non-positive elements.
- + */
- + TensorBufferFloat(@NonNull int[] shape) {
- + super(shape);
- }
- - return intArr;
- - }
- -
- - @Override
- - public int getIntValue(int absIndex) {
- - return (int) buffer.getFloat(absIndex << 2);
- - }
- -
- - @Override
- - public int getTypeSize() {
- - return DATA_TYPE.byteSize();
- - }
- -
- - @Override
- - public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
- - SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
- - SupportPreconditions.checkArgument(
- - src.length == computeFlatSize(shape),
- - "The size of the array to be loaded does not match the specified shape.");
- - copyByteBufferIfReadOnly();
- - resize(shape);
- - buffer.rewind();
- -
- - FloatBuffer floatBuffer = buffer.asFloatBuffer();
- - floatBuffer.put(src);
- - }
- -
- - @Override
- - public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
- - SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
- - SupportPreconditions.checkArgument(
- - src.length == computeFlatSize(shape),
- - "The size of the array to be loaded does not match the specified shape.");
- - copyByteBufferIfReadOnly();
- - resize(shape);
- - buffer.rewind();
- -
- - float[] floatArray = new float[src.length];
- - int cnt = 0;
- - for (int a : src) {
- - floatArray[cnt++] = (float) a;
- +
- + TensorBufferFloat() {
- + super();
- + }
- +
- + @Override
- + public DataType getDataType() {
- + return DATA_TYPE;
- + }
- +
- + @Override
- + @NonNull
- + public float[] getFloatArray() {
- + buffer.rewind();
- + float[] arr = new float[flatSize];
- +
- + FloatBuffer floatBuffer = buffer.asFloatBuffer();
- + floatBuffer.get(arr);
- + return arr;
- + }
- +
- + @Override
- + public float getFloatValue(int absIndex) {
- + return buffer.getFloat(absIndex << 2);
- + }
- +
- + @Override
- + @NonNull
- + public int[] getIntArray() {
- + buffer.rewind();
- + float[] floatArr = new float[flatSize];
- + buffer.asFloatBuffer().get(floatArr);
- +
- + int[] intArr = new int[flatSize];
- + for (int i = 0; i < flatSize; i++) {
- + intArr[i] = (int) floatArr[i];
- + }
- + return intArr;
- + }
- +
- + @Override
- + public int getIntValue(int absIndex) {
- + return (int) buffer.getFloat(absIndex << 2);
- + }
- +
- + @Override
- + public int getTypeSize() {
- + return DATA_TYPE.byteSize();
- + }
- +
- + @Override
- + public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
- + SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
- + SupportPreconditions.checkArgument(src.length == computeFlatSize(shape),
- + "The size of the array to be loaded does not match the specified shape.");
- + copyByteBufferIfReadOnly();
- + resize(shape);
- + buffer.rewind();
- +
- + FloatBuffer floatBuffer = buffer.asFloatBuffer();
- + floatBuffer.put(src);
- + }
- +
- + @Override
- + public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
- + SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
- + SupportPreconditions.checkArgument(src.length == computeFlatSize(shape),
- + "The size of the array to be loaded does not match the specified shape.");
- + copyByteBufferIfReadOnly();
- + resize(shape);
- + buffer.rewind();
- +
- + float[] floatArray = new float[src.length];
- + int cnt = 0;
- + for (int a : src) {
- + floatArray[cnt++] = (float) a;
- + }
- + buffer.asFloatBuffer().put(floatArray);
- }
- - buffer.asFloatBuffer().put(floatArray);
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java
- index b2fa466e5be92..2924ef0af6c11 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java
- @@ -21,103 +21,101 @@ import org.tensorflow.lite.support.common.internal.SupportPreconditions;
-
- /** Represents data buffer with 8-bit unsigned integer values. */
- public final class TensorBufferUint8 extends TensorBuffer {
- - private static final DataType DATA_TYPE = DataType.UINT8;
- -
- - /**
- - * Creates a {@link TensorBufferUint8} with specified {@code shape}.
- - *
- - * @throws NullPointerException if {@code shape} is null.
- - * @throws IllegalArgumentException if {@code shape} has non-positive elements.
- - */
- - TensorBufferUint8(@NonNull int[] shape) {
- - super(shape);
- - }
- -
- - TensorBufferUint8() {
- - super();
- - }
- -
- - @Override
- - public DataType getDataType() {
- - return DATA_TYPE;
- - }
- -
- - @Override
- - @NonNull
- - public float[] getFloatArray() {
- - buffer.rewind();
- - byte[] byteArr = new byte[flatSize];
- - buffer.get(byteArr);
- -
- - float[] floatArr = new float[flatSize];
- - for (int i = 0; i < flatSize; i++) {
- - floatArr[i] = (float) (byteArr[i] & 0xff);
- + private static final DataType DATA_TYPE = DataType.UINT8;
- +
- + /**
- + * Creates a {@link TensorBufferUint8} with specified {@code shape}.
- + *
- + * @throws NullPointerException if {@code shape} is null.
- + * @throws IllegalArgumentException if {@code shape} has non-positive elements.
- + */
- + TensorBufferUint8(@NonNull int[] shape) {
- + super(shape);
- }
- - return floatArr;
- - }
- -
- - @Override
- - public float getFloatValue(int index) {
- - return (float) (buffer.get(index) & 0xff);
- - }
- -
- - @Override
- - @NonNull
- - public int[] getIntArray() {
- - buffer.rewind();
- - byte[] byteArr = new byte[flatSize];
- - buffer.get(byteArr);
- -
- - int[] intArr = new int[flatSize];
- - for (int i = 0; i < flatSize; i++) {
- - intArr[i] = byteArr[i] & 0xff;
- +
- + TensorBufferUint8() {
- + super();
- + }
- +
- + @Override
- + public DataType getDataType() {
- + return DATA_TYPE;
- + }
- +
- + @Override
- + @NonNull
- + public float[] getFloatArray() {
- + buffer.rewind();
- + byte[] byteArr = new byte[flatSize];
- + buffer.get(byteArr);
- +
- + float[] floatArr = new float[flatSize];
- + for (int i = 0; i < flatSize; i++) {
- + floatArr[i] = (float) (byteArr[i] & 0xff);
- + }
- + return floatArr;
- + }
- +
- + @Override
- + public float getFloatValue(int index) {
- + return (float) (buffer.get(index) & 0xff);
- }
- - return intArr;
- - }
- -
- - @Override
- - public int getIntValue(int index) {
- - return buffer.get(index) & 0xff;
- - }
- -
- - @Override
- - public int getTypeSize() {
- - return DATA_TYPE.byteSize();
- - }
- -
- - @Override
- - public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
- - SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
- - SupportPreconditions.checkArgument(
- - src.length == computeFlatSize(shape),
- - "The size of the array to be loaded does not match the specified shape.");
- - copyByteBufferIfReadOnly();
- - resize(shape);
- - buffer.rewind();
- -
- - byte[] byteArr = new byte[src.length];
- - int cnt = 0;
- - for (float a : src) {
- - byteArr[cnt++] = (byte) Math.max(Math.min(a, 255.0), 0.0);
- +
- + @Override
- + @NonNull
- + public int[] getIntArray() {
- + buffer.rewind();
- + byte[] byteArr = new byte[flatSize];
- + buffer.get(byteArr);
- +
- + int[] intArr = new int[flatSize];
- + for (int i = 0; i < flatSize; i++) {
- + intArr[i] = byteArr[i] & 0xff;
- + }
- + return intArr;
- }
- - buffer.put(byteArr);
- - }
- -
- - @Override
- - public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
- - SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
- - SupportPreconditions.checkArgument(
- - src.length == computeFlatSize(shape),
- - "The size of the array to be loaded does not match the specified shape.");
- - copyByteBufferIfReadOnly();
- - resize(shape);
- - buffer.rewind();
- -
- - byte[] byteArr = new byte[src.length];
- - int cnt = 0;
- - for (float a : src) {
- - byteArr[cnt++] = (byte) Math.max(Math.min(a, 255), 0);
- +
- + @Override
- + public int getIntValue(int index) {
- + return buffer.get(index) & 0xff;
- + }
- +
- + @Override
- + public int getTypeSize() {
- + return DATA_TYPE.byteSize();
- + }
- +
- + @Override
- + public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
- + SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
- + SupportPreconditions.checkArgument(src.length == computeFlatSize(shape),
- + "The size of the array to be loaded does not match the specified shape.");
- + copyByteBufferIfReadOnly();
- + resize(shape);
- + buffer.rewind();
- +
- + byte[] byteArr = new byte[src.length];
- + int cnt = 0;
- + for (float a : src) {
- + byteArr[cnt++] = (byte) Math.max(Math.min(a, 255.0), 0.0);
- + }
- + buffer.put(byteArr);
- + }
- +
- + @Override
- + public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
- + SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
- + SupportPreconditions.checkArgument(src.length == computeFlatSize(shape),
- + "The size of the array to be loaded does not match the specified shape.");
- + copyByteBufferIfReadOnly();
- + resize(shape);
- + buffer.rewind();
- +
- + byte[] byteArr = new byte[src.length];
- + int cnt = 0;
- + for (float a : src) {
- + byteArr[cnt++] = (byte) Math.max(Math.min(a, 255), 0);
- + }
- + buffer.put(byteArr);
- }
- - buffer.put(byteArr);
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java
- index 043528aa88138..85c5d12e2fc53 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java
- @@ -22,13 +22,7 @@ import android.media.AudioFormat;
- import android.media.AudioRecord;
- import android.media.MediaRecorder;
- import android.os.ParcelFileDescriptor;
- -import java.io.File;
- -import java.io.IOException;
- -import java.nio.ByteBuffer;
- -import java.nio.MappedByteBuffer;
- -import java.util.ArrayList;
- -import java.util.Collections;
- -import java.util.List;
- +
- import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.audio.TensorAudio;
- import org.tensorflow.lite.support.audio.TensorAudio.TensorAudioFormat;
- @@ -40,6 +34,14 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
- import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider;
- import org.tensorflow.lite.task.core.annotations.UsedByReflection;
-
- +import java.io.File;
- +import java.io.IOException;
- +import java.nio.ByteBuffer;
- +import java.nio.MappedByteBuffer;
- +import java.util.ArrayList;
- +import java.util.Collections;
- +import java.util.List;
- +
- /**
- * Performs classification on audio waveforms.
- *
- @@ -72,468 +74,437 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection;
- * CLI demo tool</a> for easily trying out this API.
- */
- public final class AudioClassifier extends BaseTaskApi {
- + private static final String AUDIO_CLASSIFIER_NATIVE_LIB = "task_audio_jni";
- + private static final int OPTIONAL_FD_LENGTH = -1;
- + private static final int OPTIONAL_FD_OFFSET = -1;
- +
- + /**
- + * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}.
- + *
- + * @param modelPath path of the classification model with metadata in the assets
- + * @throws IOException if an I/O error occurs when loading the tflite model
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static AudioClassifier createFromFile(Context context, String modelPath)
- + throws IOException {
- + return createFromFileAndOptions(
- + context, modelPath, AudioClassifierOptions.builder().build());
- + }
-
- - private static final String AUDIO_CLASSIFIER_NATIVE_LIB = "task_audio_jni";
- - private static final int OPTIONAL_FD_LENGTH = -1;
- - private static final int OPTIONAL_FD_OFFSET = -1;
- -
- - /**
- - * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}.
- - *
- - * @param modelPath path of the classification model with metadata in the assets
- - * @throws IOException if an I/O error occurs when loading the tflite model
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static AudioClassifier createFromFile(Context context, String modelPath)
- - throws IOException {
- - return createFromFileAndOptions(context, modelPath, AudioClassifierOptions.builder().build());
- - }
- -
- - /**
- - * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}.
- - *
- - * @param modelFile the classification model {@link File} instance
- - * @throws IOException if an I/O error occurs when loading the tflite model
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static AudioClassifier createFromFile(File modelFile) throws IOException {
- - return createFromFileAndOptions(modelFile, AudioClassifierOptions.builder().build());
- - }
- -
- - /**
- - * Creates an {@link AudioClassifier} instance with a model buffer and the default {@link
- - * AudioClassifierOptions}.
- - *
- - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- - * classification model
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- - * {@link MappedByteBuffer}
- - */
- - public static AudioClassifier createFromBuffer(final ByteBuffer modelBuffer) {
- - return createFromBufferAndOptions(modelBuffer, AudioClassifierOptions.builder().build());
- - }
- -
- - /**
- - * Creates an {@link AudioClassifier} instance from {@link AudioClassifierOptions}.
- - *
- - * @param modelPath path of the classification model with metadata in the assets
- - * @throws IOException if an I/O error occurs when loading the tflite model
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static AudioClassifier createFromFileAndOptions(
- - Context context, String modelPath, AudioClassifierOptions options) throws IOException {
- - return new AudioClassifier(
- - TaskJniUtils.createHandleFromFdAndOptions(
- - context,
- - new FdAndOptionsHandleProvider<AudioClassifierOptions>() {
- - @Override
- - public long createHandle(
- - int fileDescriptor,
- - long fileDescriptorLength,
- - long fileDescriptorOffset,
- - AudioClassifierOptions options) {
- - return initJniWithModelFdAndOptions(
- - fileDescriptor,
- - fileDescriptorLength,
- - fileDescriptorOffset,
- - options,
- - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
- - }
- - },
- - AUDIO_CLASSIFIER_NATIVE_LIB,
- - modelPath,
- - options));
- - }
- -
- - /**
- - * Creates an {@link AudioClassifier} instance.
- - *
- - * @param modelFile the classification model {@link File} instance
- - * @throws IOException if an I/O error occurs when loading the tflite model
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static AudioClassifier createFromFileAndOptions(
- - File modelFile, final AudioClassifierOptions options) throws IOException {
- - try (ParcelFileDescriptor descriptor =
- - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- - return new AudioClassifier(
- - TaskJniUtils.createHandleFromLibrary(
- - new TaskJniUtils.EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - return initJniWithModelFdAndOptions(
- - descriptor.getFd(),
- - /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
- - /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
- - options,
- - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
- - }
- - },
- - AUDIO_CLASSIFIER_NATIVE_LIB));
- + /**
- + * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}.
- + *
- + * @param modelFile the classification model {@link File} instance
- + * @throws IOException if an I/O error occurs when loading the tflite model
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static AudioClassifier createFromFile(File modelFile) throws IOException {
- + return createFromFileAndOptions(modelFile, AudioClassifierOptions.builder().build());
- }
- - }
- -
- - /**
- - * Creates an {@link AudioClassifier} instance with a model buffer and {@link
- - * AudioClassifierOptions}.
- - *
- - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- - * classification model
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- - * {@link MappedByteBuffer}
- - */
- - public static AudioClassifier createFromBufferAndOptions(
- - final ByteBuffer modelBuffer, final AudioClassifierOptions options) {
- - if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- - throw new IllegalArgumentException(
- - "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
- +
- + /**
- + * Creates an {@link AudioClassifier} instance with a model buffer and the default {@link
- + * AudioClassifierOptions}.
- + *
- + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- + * classification model
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- + * {@link MappedByteBuffer}
- + */
- + public static AudioClassifier createFromBuffer(final ByteBuffer modelBuffer) {
- + return createFromBufferAndOptions(modelBuffer, AudioClassifierOptions.builder().build());
- }
- - return new AudioClassifier(
- - TaskJniUtils.createHandleFromLibrary(
- - new EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - return initJniWithByteBuffer(
- - modelBuffer,
- - options,
- - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
- - }
- - },
- - AUDIO_CLASSIFIER_NATIVE_LIB));
- - }
- -
- - /**
- - * Constructor to initialize the JNI with a pointer from C++.
- - *
- - * @param nativeHandle a pointer referencing memory allocated in C++
- - */
- - private AudioClassifier(long nativeHandle) {
- - super(nativeHandle);
- - }
- -
- - /** Options for setting up an {@link AudioClassifier}. */
- - @UsedByReflection("audio_classifier_jni.cc")
- - public static class AudioClassifierOptions {
- - // Not using AutoValue for this class because scoreThreshold cannot have default value
- - // (otherwise, the default value would override the one in the model metadata) and `Optional` is
- - // not an option here, because
- - // 1. java.util.Optional require Java 8 while we need to support Java 7.
- - // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the
- - // comments for labelAllowList.
- - private final BaseOptions baseOptions;
- - private final String displayNamesLocale;
- - private final int maxResults;
- - private final float scoreThreshold;
- - private final boolean isScoreThresholdSet;
- - // As an open source project, we've been trying avoiding depending on common java libraries,
- - // such as Guava, because it may introduce conflicts with clients who also happen to use those
- - // libraries. Therefore, instead of using ImmutableList here, we convert the List into
- - // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
- - // vulnerable.
- - private final List<String> labelAllowList;
- - private final List<String> labelDenyList;
- -
- - public static Builder builder() {
- - return new Builder();
- +
- + /**
- + * Creates an {@link AudioClassifier} instance from {@link AudioClassifierOptions}.
- + *
- + * @param modelPath path of the classification model with metadata in the assets
- + * @throws IOException if an I/O error occurs when loading the tflite model
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static AudioClassifier createFromFileAndOptions(
- + Context context, String modelPath, AudioClassifierOptions options) throws IOException {
- + return new AudioClassifier(TaskJniUtils.createHandleFromFdAndOptions(
- + context, new FdAndOptionsHandleProvider<AudioClassifierOptions>() {
- + @Override
- + public long createHandle(int fileDescriptor, long fileDescriptorLength,
- + long fileDescriptorOffset, AudioClassifierOptions options) {
- + return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength,
- + fileDescriptorOffset, options,
- + TaskJniUtils.createProtoBaseOptionsHandle(
- + options.getBaseOptions()));
- + }
- + }, AUDIO_CLASSIFIER_NATIVE_LIB, modelPath, options));
- }
-
- - /** A builder that helps to configure an instance of AudioClassifierOptions. */
- - public static class Builder {
- - private BaseOptions baseOptions = BaseOptions.builder().build();
- - private String displayNamesLocale = "en";
- - private int maxResults = -1;
- - private float scoreThreshold;
- - private boolean isScoreThresholdSet;
- - private List<String> labelAllowList = new ArrayList<>();
- - private List<String> labelDenyList = new ArrayList<>();
- -
- - private Builder() {}
- -
- - /** Sets the general options to configure Task APIs, such as accelerators. */
- - public Builder setBaseOptions(BaseOptions baseOptions) {
- - this.baseOptions = baseOptions;
- - return this;
- - }
- -
- - /**
- - * Sets the locale to use for display names specified through the TFLite Model Metadata, if
- - * any.
- - *
- - * <p>Defaults to English({@code "en"}). See the <a
- - * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
- - * Metadata schema file.</a> for the accepted pattern of locale.
- - */
- - public Builder setDisplayNamesLocale(String displayNamesLocale) {
- - this.displayNamesLocale = displayNamesLocale;
- - return this;
- - }
- -
- - /**
- - * Sets the maximum number of top scored results to return.
- - *
- - * @param maxResults if < 0, all results will be returned. If 0, an invalid argument error is
- - * returned. Defaults to -1.
- - * @throws IllegalArgumentException if maxResults is 0
- - */
- - public Builder setMaxResults(int maxResults) {
- - if (maxResults == 0) {
- - throw new IllegalArgumentException("maxResults cannot be 0.");
- + /**
- + * Creates an {@link AudioClassifier} instance.
- + *
- + * @param modelFile the classification model {@link File} instance
- + * @throws IOException if an I/O error occurs when loading the tflite model
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static AudioClassifier createFromFileAndOptions(
- + File modelFile, final AudioClassifierOptions options) throws IOException {
- + try (ParcelFileDescriptor descriptor =
- + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- + return new AudioClassifier(
- + TaskJniUtils.createHandleFromLibrary(new TaskJniUtils.EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + return initJniWithModelFdAndOptions(descriptor.getFd(),
- + /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH,
- + /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options,
- + TaskJniUtils.createProtoBaseOptionsHandle(
- + options.getBaseOptions()));
- + }
- + }, AUDIO_CLASSIFIER_NATIVE_LIB));
- }
- - this.maxResults = maxResults;
- - return this;
- - }
- -
- - /**
- - * Sets the score threshold.
- - *
- - * <p>It overrides the one provided in the model metadata (if any). Results below this value
- - * are rejected.
- - */
- - public Builder setScoreThreshold(float scoreThreshold) {
- - this.scoreThreshold = scoreThreshold;
- - isScoreThresholdSet = true;
- - return this;
- - }
- -
- - /**
- - * Sets the optional allowlist of labels.
- - *
- - * <p>If non-empty, classifications whose label is not in this set will be filtered out.
- - * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList.
- - */
- - public Builder setLabelAllowList(List<String> labelAllowList) {
- - this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
- - return this;
- - }
- -
- - /**
- - * Sets the optional denylist of labels.
- - *
- - * <p>If non-empty, classifications whose label is in this set will be filtered out. Duplicate
- - * or unknown labels are ignored. Mutually exclusive with labelAllowList.
- - */
- - public Builder setLabelDenyList(List<String> labelDenyList) {
- - this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
- - return this;
- - }
- -
- - public AudioClassifierOptions build() {
- - return new AudioClassifierOptions(this);
- - }
- }
-
- - @UsedByReflection("audio_classifier_jni.cc")
- - public String getDisplayNamesLocale() {
- - return displayNamesLocale;
- + /**
- + * Creates an {@link AudioClassifier} instance with a model buffer and {@link
- + * AudioClassifierOptions}.
- + *
- + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- + * classification model
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- + * {@link MappedByteBuffer}
- + */
- + public static AudioClassifier createFromBufferAndOptions(
- + final ByteBuffer modelBuffer, final AudioClassifierOptions options) {
- + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- + throw new IllegalArgumentException(
- + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
- + }
- + return new AudioClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + return initJniWithByteBuffer(modelBuffer, options,
- + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
- + }
- + }, AUDIO_CLASSIFIER_NATIVE_LIB));
- }
-
- - @UsedByReflection("audio_classifier_jni.cc")
- - public int getMaxResults() {
- - return maxResults;
- + /**
- + * Constructor to initialize the JNI with a pointer from C++.
- + *
- + * @param nativeHandle a pointer referencing memory allocated in C++
- + */
- + private AudioClassifier(long nativeHandle) {
- + super(nativeHandle);
- }
-
- + /** Options for setting up an {@link AudioClassifier}. */
- @UsedByReflection("audio_classifier_jni.cc")
- - public float getScoreThreshold() {
- - return scoreThreshold;
- + public static class AudioClassifierOptions {
- + // Not using AutoValue for this class because scoreThreshold cannot have default value
- + // (otherwise, the default value would override the one in the model metadata) and
- + // `Optional` is not an option here, because
- + // 1. java.util.Optional require Java 8 while we need to support Java 7.
- + // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See
- + // the comments for labelAllowList.
- + private final BaseOptions baseOptions;
- + private final String displayNamesLocale;
- + private final int maxResults;
- + private final float scoreThreshold;
- + private final boolean isScoreThresholdSet;
- + // As an open source project, we've been trying avoiding depending on common java libraries,
- + // such as Guava, because it may introduce conflicts with clients who also happen to use
- + // those libraries. Therefore, instead of using ImmutableList here, we convert the List into
- + // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
- + // vulnerable.
- + private final List<String> labelAllowList;
- + private final List<String> labelDenyList;
- +
- + public static Builder builder() {
- + return new Builder();
- + }
- +
- + /** A builder that helps to configure an instance of AudioClassifierOptions. */
- + public static class Builder {
- + private BaseOptions baseOptions = BaseOptions.builder().build();
- + private String displayNamesLocale = "en";
- + private int maxResults = -1;
- + private float scoreThreshold;
- + private boolean isScoreThresholdSet;
- + private List<String> labelAllowList = new ArrayList<>();
- + private List<String> labelDenyList = new ArrayList<>();
- +
- + private Builder() {}
- +
- + /** Sets the general options to configure Task APIs, such as accelerators. */
- + public Builder setBaseOptions(BaseOptions baseOptions) {
- + this.baseOptions = baseOptions;
- + return this;
- + }
- +
- + /**
- + * Sets the locale to use for display names specified through the TFLite Model Metadata,
- + * if any.
- + *
- + * <p>Defaults to English({@code "en"}). See the <a
- + * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
- + * Metadata schema file.</a> for the accepted pattern of locale.
- + */
- + public Builder setDisplayNamesLocale(String displayNamesLocale) {
- + this.displayNamesLocale = displayNamesLocale;
- + return this;
- + }
- +
- + /**
- + * Sets the maximum number of top scored results to return.
- + *
- + * @param maxResults if < 0, all results will be returned. If 0, an invalid argument
- + * error is
- + * returned. Defaults to -1.
- + * @throws IllegalArgumentException if maxResults is 0
- + */
- + public Builder setMaxResults(int maxResults) {
- + if (maxResults == 0) {
- + throw new IllegalArgumentException("maxResults cannot be 0.");
- + }
- + this.maxResults = maxResults;
- + return this;
- + }
- +
- + /**
- + * Sets the score threshold.
- + *
- + * <p>It overrides the one provided in the model metadata (if any). Results below this
- + * value are rejected.
- + */
- + public Builder setScoreThreshold(float scoreThreshold) {
- + this.scoreThreshold = scoreThreshold;
- + isScoreThresholdSet = true;
- + return this;
- + }
- +
- + /**
- + * Sets the optional allowlist of labels.
- + *
- + * <p>If non-empty, classifications whose label is not in this set will be filtered out.
- + * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList.
- + */
- + public Builder setLabelAllowList(List<String> labelAllowList) {
- + this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
- + return this;
- + }
- +
- + /**
- + * Sets the optional denylist of labels.
- + *
- + * <p>If non-empty, classifications whose label is in this set will be filtered out.
- + * Duplicate or unknown labels are ignored. Mutually exclusive with labelAllowList.
- + */
- + public Builder setLabelDenyList(List<String> labelDenyList) {
- + this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
- + return this;
- + }
- +
- + public AudioClassifierOptions build() {
- + return new AudioClassifierOptions(this);
- + }
- + }
- +
- + @UsedByReflection("audio_classifier_jni.cc")
- + public String getDisplayNamesLocale() {
- + return displayNamesLocale;
- + }
- +
- + @UsedByReflection("audio_classifier_jni.cc")
- + public int getMaxResults() {
- + return maxResults;
- + }
- +
- + @UsedByReflection("audio_classifier_jni.cc")
- + public float getScoreThreshold() {
- + return scoreThreshold;
- + }
- +
- + @UsedByReflection("audio_classifier_jni.cc")
- + public boolean getIsScoreThresholdSet() {
- + return isScoreThresholdSet;
- + }
- +
- + @UsedByReflection("audio_classifier_jni.cc")
- + public List<String> getLabelAllowList() {
- + return new ArrayList<>(labelAllowList);
- + }
- +
- + @UsedByReflection("audio_classifier_jni.cc")
- + public List<String> getLabelDenyList() {
- + return new ArrayList<>(labelDenyList);
- + }
- +
- + public BaseOptions getBaseOptions() {
- + return baseOptions;
- + }
- +
- + private AudioClassifierOptions(Builder builder) {
- + displayNamesLocale = builder.displayNamesLocale;
- + maxResults = builder.maxResults;
- + scoreThreshold = builder.scoreThreshold;
- + isScoreThresholdSet = builder.isScoreThresholdSet;
- + labelAllowList = builder.labelAllowList;
- + labelDenyList = builder.labelDenyList;
- + baseOptions = builder.baseOptions;
- + }
- }
-
- - @UsedByReflection("audio_classifier_jni.cc")
- - public boolean getIsScoreThresholdSet() {
- - return isScoreThresholdSet;
- + /**
- + * Performs actual classification on the provided audio tensor.
- + *
- + * @param tensor a {@link TensorAudio} containing the input audio clip in float with values
- + * between [-1, 1). The {@code tensor} argument should have the same flat size as the TFLite
- + * model's input tensor. It's recommended to create {@code tensor} using {@code
- + * createInputTensorAudio} method.
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if error occurs when classifying the audio clip from the native
- + * code
- + */
- + public List<Classifications> classify(TensorAudio tensor) {
- + TensorBuffer buffer = tensor.getTensorBuffer();
- + TensorAudioFormat format = tensor.getFormat();
- + checkState(buffer.getBuffer().hasArray(),
- + "Input tensor buffer should be a non-direct buffer with a backed array (i.e. not readonly"
- + + " buffer).");
- + return classifyNative(getNativeHandle(), buffer.getBuffer().array(), format.getChannels(),
- + format.getSampleRate());
- }
-
- - @UsedByReflection("audio_classifier_jni.cc")
- - public List<String> getLabelAllowList() {
- - return new ArrayList<>(labelAllowList);
- + /**
- + * Creates a {@link TensorAudio} instance to store input audio samples.
- + *
- + * @return a {@link TensorAudio} with the same size as model input tensor
- + * @throws IllegalArgumentException if the model is not compatible
- + */
- + public TensorAudio createInputTensorAudio() {
- + TensorAudioFormat format = getRequiredTensorAudioFormat();
- +
- + long bufferSize = getRequiredInputBufferSize();
- + long samples = bufferSize / format.getChannels();
- + return TensorAudio.create(format, (int) samples);
- }
-
- - @UsedByReflection("audio_classifier_jni.cc")
- - public List<String> getLabelDenyList() {
- - return new ArrayList<>(labelDenyList);
- + /** Returns the required input buffer size in number of float elements. */
- + public long getRequiredInputBufferSize() {
- + return getRequiredInputBufferSizeNative(getNativeHandle());
- }
-
- - public BaseOptions getBaseOptions() {
- - return baseOptions;
- + /**
- + * Creates an {@link android.media.AudioRecord} instance to record audio stream. The returned
- + * AudioRecord instance is initialized and client needs to call {@link
- + * android.media.AudioRecord#startRecording} method to start recording.
- + *
- + * @return an {@link android.media.AudioRecord} instance in {@link
- + * android.media.AudioRecord#STATE_INITIALIZED}
- + * @throws IllegalArgumentException if the model required channel count is unsupported
- + * @throws IllegalStateException if AudioRecord instance failed to initialize
- + */
- + public AudioRecord createAudioRecord() {
- + TensorAudioFormat format = getRequiredTensorAudioFormat();
- + int channelConfig = 0;
- +
- + switch (format.getChannels()) {
- + case 1:
- + channelConfig = AudioFormat.CHANNEL_IN_MONO;
- + break;
- + case 2:
- + channelConfig = AudioFormat.CHANNEL_IN_STEREO;
- + break;
- + default:
- + throw new IllegalArgumentException(String.format(
- + "Number of channels required by the model is %d. getAudioRecord method only"
- + + " supports 1 or 2 audio channels.",
- + format.getChannels()));
- + }
- +
- + int bufferSizeInBytes = AudioRecord.getMinBufferSize(
- + format.getSampleRate(), channelConfig, AudioFormat.ENCODING_PCM_FLOAT);
- + if (bufferSizeInBytes == AudioRecord.ERROR
- + || bufferSizeInBytes == AudioRecord.ERROR_BAD_VALUE) {
- + throw new IllegalStateException(String.format(
- + "AudioRecord.getMinBufferSize failed. Returned: %d", bufferSizeInBytes));
- + }
- + // The buffer of AudioRecord should be strictly longer than what model requires so that
- + // clients could run `TensorAudio::load(record)` together with `AudioClassifier::classify`.
- + int bufferSizeMultiplier = 2;
- + int modelRequiredBufferSize = (int) getRequiredInputBufferSize()
- + * DataType.FLOAT32.byteSize() * bufferSizeMultiplier;
- + if (bufferSizeInBytes < modelRequiredBufferSize) {
- + bufferSizeInBytes = modelRequiredBufferSize;
- + }
- + AudioRecord audioRecord = new AudioRecord(
- + // including MIC, UNPROCESSED, and CAMCORDER.
- + MediaRecorder.AudioSource.VOICE_RECOGNITION, format.getSampleRate(), channelConfig,
- + AudioFormat.ENCODING_PCM_FLOAT, bufferSizeInBytes);
- + checkState(audioRecord.getState() == AudioRecord.STATE_INITIALIZED,
- + "AudioRecord failed to initialize");
- + return audioRecord;
- }
-
- - private AudioClassifierOptions(Builder builder) {
- - displayNamesLocale = builder.displayNamesLocale;
- - maxResults = builder.maxResults;
- - scoreThreshold = builder.scoreThreshold;
- - isScoreThresholdSet = builder.isScoreThresholdSet;
- - labelAllowList = builder.labelAllowList;
- - labelDenyList = builder.labelDenyList;
- - baseOptions = builder.baseOptions;
- + /** Returns the {@link TensorAudioFormat} required by the model. */
- + public TensorAudioFormat getRequiredTensorAudioFormat() {
- + return TensorAudioFormat.builder()
- + .setChannels(getRequiredChannels())
- + .setSampleRate(getRequiredSampleRate())
- + .build();
- }
- - }
- -
- - /**
- - * Performs actual classification on the provided audio tensor.
- - *
- - * @param tensor a {@link TensorAudio} containing the input audio clip in float with values
- - * between [-1, 1). The {@code tensor} argument should have the same flat size as the TFLite
- - * model's input tensor. It's recommended to create {@code tensor} using {@code
- - * createInputTensorAudio} method.
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if error occurs when classifying the audio clip from the native
- - * code
- - */
- - public List<Classifications> classify(TensorAudio tensor) {
- - TensorBuffer buffer = tensor.getTensorBuffer();
- - TensorAudioFormat format = tensor.getFormat();
- - checkState(
- - buffer.getBuffer().hasArray(),
- - "Input tensor buffer should be a non-direct buffer with a backed array (i.e. not readonly"
- - + " buffer).");
- - return classifyNative(
- - getNativeHandle(),
- - buffer.getBuffer().array(),
- - format.getChannels(),
- - format.getSampleRate());
- - }
- -
- - /**
- - * Creates a {@link TensorAudio} instance to store input audio samples.
- - *
- - * @return a {@link TensorAudio} with the same size as model input tensor
- - * @throws IllegalArgumentException if the model is not compatible
- - */
- - public TensorAudio createInputTensorAudio() {
- - TensorAudioFormat format = getRequiredTensorAudioFormat();
- -
- - long bufferSize = getRequiredInputBufferSize();
- - long samples = bufferSize / format.getChannels();
- - return TensorAudio.create(format, (int) samples);
- - }
- -
- - /** Returns the required input buffer size in number of float elements. */
- - public long getRequiredInputBufferSize() {
- - return getRequiredInputBufferSizeNative(getNativeHandle());
- - }
- -
- - /**
- - * Creates an {@link android.media.AudioRecord} instance to record audio stream. The returned
- - * AudioRecord instance is initialized and client needs to call {@link
- - * android.media.AudioRecord#startRecording} method to start recording.
- - *
- - * @return an {@link android.media.AudioRecord} instance in {@link
- - * android.media.AudioRecord#STATE_INITIALIZED}
- - * @throws IllegalArgumentException if the model required channel count is unsupported
- - * @throws IllegalStateException if AudioRecord instance failed to initialize
- - */
- - public AudioRecord createAudioRecord() {
- - TensorAudioFormat format = getRequiredTensorAudioFormat();
- - int channelConfig = 0;
- -
- - switch (format.getChannels()) {
- - case 1:
- - channelConfig = AudioFormat.CHANNEL_IN_MONO;
- - break;
- - case 2:
- - channelConfig = AudioFormat.CHANNEL_IN_STEREO;
- - break;
- - default:
- - throw new IllegalArgumentException(
- - String.format(
- - "Number of channels required by the model is %d. getAudioRecord method only"
- - + " supports 1 or 2 audio channels.",
- - format.getChannels()));
- +
- + private int getRequiredChannels() {
- + return getRequiredChannelsNative(getNativeHandle());
- }
-
- - int bufferSizeInBytes =
- - AudioRecord.getMinBufferSize(
- - format.getSampleRate(), channelConfig, AudioFormat.ENCODING_PCM_FLOAT);
- - if (bufferSizeInBytes == AudioRecord.ERROR
- - || bufferSizeInBytes == AudioRecord.ERROR_BAD_VALUE) {
- - throw new IllegalStateException(
- - String.format("AudioRecord.getMinBufferSize failed. Returned: %d", bufferSizeInBytes));
- + private int getRequiredSampleRate() {
- + return getRequiredSampleRateNative(getNativeHandle());
- }
- - // The buffer of AudioRecord should be strictly longer than what model requires so that clients
- - // could run `TensorAudio::load(record)` together with `AudioClassifier::classify`.
- - int bufferSizeMultiplier = 2;
- - int modelRequiredBufferSize =
- - (int) getRequiredInputBufferSize() * DataType.FLOAT32.byteSize() * bufferSizeMultiplier;
- - if (bufferSizeInBytes < modelRequiredBufferSize) {
- - bufferSizeInBytes = modelRequiredBufferSize;
- +
- + // TODO(b/183343074): JNI method invocation is very expensive, taking about .2ms
- + // each time. Consider combining the native getter methods into 1 and cache it in Java layer.
- + private static native long getRequiredInputBufferSizeNative(long nativeHandle);
- +
- + private static native int getRequiredChannelsNative(long nativeHandle);
- +
- + private static native int getRequiredSampleRateNative(long nativeHandle);
- +
- + private static native List<Classifications> classifyNative(
- + long nativeHandle, byte[] audioBuffer, int channels, int sampleRate);
- +
- + private static native long initJniWithModelFdAndOptions(int fileDescriptor,
- + long fileDescriptorLength, long fileDescriptorOffset, AudioClassifierOptions options,
- + long baseOptionsHandle);
- +
- + private static native long initJniWithByteBuffer(
- + ByteBuffer modelBuffer, AudioClassifierOptions options, long baseOptionsHandle);
- +
- + /**
- + * Releases memory pointed by {@code nativeHandle}, namely a C++ `AudioClassifier` instance.
- + *
- + * @param nativeHandle pointer to memory allocated
- + */
- + @Override
- + protected void deinit(long nativeHandle) {
- + deinitJni(nativeHandle);
- }
- - AudioRecord audioRecord =
- - new AudioRecord(
- - // including MIC, UNPROCESSED, and CAMCORDER.
- - MediaRecorder.AudioSource.VOICE_RECOGNITION,
- - format.getSampleRate(),
- - channelConfig,
- - AudioFormat.ENCODING_PCM_FLOAT,
- - bufferSizeInBytes);
- - checkState(
- - audioRecord.getState() == AudioRecord.STATE_INITIALIZED,
- - "AudioRecord failed to initialize");
- - return audioRecord;
- - }
- -
- - /** Returns the {@link TensorAudioFormat} required by the model. */
- - public TensorAudioFormat getRequiredTensorAudioFormat() {
- - return TensorAudioFormat.builder()
- - .setChannels(getRequiredChannels())
- - .setSampleRate(getRequiredSampleRate())
- - .build();
- - }
- -
- - private int getRequiredChannels() {
- - return getRequiredChannelsNative(getNativeHandle());
- - }
- -
- - private int getRequiredSampleRate() {
- - return getRequiredSampleRateNative(getNativeHandle());
- - }
- -
- - // TODO(b/183343074): JNI method invocation is very expensive, taking about .2ms
- - // each time. Consider combining the native getter methods into 1 and cache it in Java layer.
- - private static native long getRequiredInputBufferSizeNative(long nativeHandle);
- -
- - private static native int getRequiredChannelsNative(long nativeHandle);
- -
- - private static native int getRequiredSampleRateNative(long nativeHandle);
- -
- - private static native List<Classifications> classifyNative(
- - long nativeHandle, byte[] audioBuffer, int channels, int sampleRate);
- -
- - private static native long initJniWithModelFdAndOptions(
- - int fileDescriptor,
- - long fileDescriptorLength,
- - long fileDescriptorOffset,
- - AudioClassifierOptions options,
- - long baseOptionsHandle);
- -
- - private static native long initJniWithByteBuffer(
- - ByteBuffer modelBuffer, AudioClassifierOptions options, long baseOptionsHandle);
- -
- - /**
- - * Releases memory pointed by {@code nativeHandle}, namely a C++ `AudioClassifier` instance.
- - *
- - * @param nativeHandle pointer to memory allocated
- - */
- - @Override
- - protected void deinit(long nativeHandle) {
- - deinitJni(nativeHandle);
- - }
- -
- - /**
- - * Native method to release memory pointed by {@code nativeHandle}, namely a C++ `AudioClassifier`
- - * instance.
- - *
- - * @param nativeHandle pointer to memory allocated
- - */
- - private static native void deinitJni(long nativeHandle);
- +
- + /**
- + * Native method to release memory pointed by {@code nativeHandle}, namely a C++
- + * `AudioClassifier` instance.
- + *
- + * @param nativeHandle pointer to memory allocated
- + */
- + private static native void deinitJni(long nativeHandle);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java
- index 9c0cdf9e249ae..8e8270269dad8 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java
- @@ -16,11 +16,13 @@ limitations under the License.
- package org.tensorflow.lite.task.audio.classifier;
-
- import com.google.auto.value.AutoValue;
- +
- +import org.tensorflow.lite.support.label.Category;
- +import org.tensorflow.lite.task.core.annotations.UsedByReflection;
- +
- import java.util.ArrayList;
- import java.util.Collections;
- import java.util.List;
- -import org.tensorflow.lite.support.label.Category;
- -import org.tensorflow.lite.task.core.annotations.UsedByReflection;
-
- /**
- * The classification results of one head in a multihead (a.k.a. multi-output) {@link
- @@ -31,18 +33,18 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection;
- @AutoValue
- @UsedByReflection("audio_classifier_jni.cc")
- public abstract class Classifications {
- + @UsedByReflection("audio_classifier_jni.cc")
- + static Classifications create(List<Category> categories, int headIndex, String headName) {
- + return new AutoValue_Classifications(
- + Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex,
- + headName);
- + }
-
- - @UsedByReflection("audio_classifier_jni.cc")
- - static Classifications create(List<Category> categories, int headIndex, String headName) {
- - return new AutoValue_Classifications(
- - Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex, headName);
- - }
- -
- - // Same reason for not using ImmutableList as stated in
- - // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}.
- - public abstract List<Category> getCategories();
- + // Same reason for not using ImmutableList as stated in
- + // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}.
- + public abstract List<Category> getCategories();
-
- - public abstract int getHeadIndex();
- + public abstract int getHeadIndex();
-
- - public abstract String getHeadName();
- + public abstract String getHeadName();
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java
- index 242414bd21bdb..b2d722332c954 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java
- @@ -20,65 +20,66 @@ import com.google.auto.value.AutoValue;
- /** Options to configure Task APIs in general. */
- @AutoValue
- public abstract class BaseOptions {
- - private static final int DEFAULT_NUM_THREADS = -1;
- + private static final int DEFAULT_NUM_THREADS = -1;
-
- - /** Builder for {@link BaseOptions}. */
- - @AutoValue.Builder
- - public abstract static class Builder {
- + /** Builder for {@link BaseOptions}. */
- + @AutoValue.Builder
- + public abstract static class Builder {
- + /**
- + * Sets the advanced accelerator options.
- + *
- + * <p>Note: this method will override those highlevel API to choose an delegate, such as
- + * {@link #useGpu} and {@link #useNnapi}.
- + */
- + public abstract Builder setComputeSettings(ComputeSettings computeSettings);
-
- - /**
- - * Sets the advanced accelerator options.
- - *
- - * <p>Note: this method will override those highlevel API to choose an delegate, such as {@link
- - * #useGpu} and {@link #useNnapi}.
- - */
- - public abstract Builder setComputeSettings(ComputeSettings computeSettings);
- + /**
- + * Sets the number of threads to be used for TFLite ops that support multi-threading when
- + * running inference with CPU. Defaults to -1.
- + *
- + * <p>{@code numThreads} should be greater than 0 or equal to -1. Setting numThreads to -1
- + * has the effect to let TFLite runtime set the value.
- + */
- + public abstract Builder setNumThreads(int numThreads);
-
- - /**
- - * Sets the number of threads to be used for TFLite ops that support multi-threading when
- - * running inference with CPU. Defaults to -1.
- - *
- - * <p>{@code numThreads} should be greater than 0 or equal to -1. Setting numThreads to -1 has
- - * the effect to let TFLite runtime set the value.
- - */
- - public abstract Builder setNumThreads(int numThreads);
- + /**
- + * Uses GPU for inference. The advanced GPU configuration settings will be set to default
- + * values.
- + *
- + * <p>Note: this method will override the settings from {@link #setComputeSettings}.
- + *
- + * <p>To manipulate the advanced GPU configuration settings, use {@link
- + * #setComputeSettings}.
- + */
- + public Builder useGpu() {
- + return setComputeSettings(
- + ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.GPU).build());
- + }
-
- - /**
- - * Uses GPU for inference. The advanced GPU configuration settings will be set to default
- - * values.
- - *
- - * <p>Note: this method will override the settings from {@link #setComputeSettings}.
- - *
- - * <p>To manipulate the advanced GPU configuration settings, use {@link #setComputeSettings}.
- - */
- - public Builder useGpu() {
- - return setComputeSettings(
- - ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.GPU).build());
- - }
- + /**
- + * Uses NNAPI for inference. The advanced NNAPI configuration settings will be set to
- + * default values.
- + *
- + * <p>Note: this method will override the settings from {@link #setComputeSettings}.
- + *
- + * <p>To manipulate the advanced NNAPI configuration settings, use {@link
- + * #setComputeSettings}.
- + */
- + public Builder useNnapi() {
- + return setComputeSettings(
- + ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.NNAPI).build());
- + }
-
- - /**
- - * Uses NNAPI for inference. The advanced NNAPI configuration settings will be set to default
- - * values.
- - *
- - * <p>Note: this method will override the settings from {@link #setComputeSettings}.
- - *
- - * <p>To manipulate the advanced NNAPI configuration settings, use {@link #setComputeSettings}.
- - */
- - public Builder useNnapi() {
- - return setComputeSettings(
- - ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.NNAPI).build());
- + public abstract BaseOptions build();
- }
-
- - public abstract BaseOptions build();
- - }
- -
- - public static Builder builder() {
- - return new AutoValue_BaseOptions.Builder()
- - .setComputeSettings(ComputeSettings.builder().build())
- - .setNumThreads(DEFAULT_NUM_THREADS);
- - }
- + public static Builder builder() {
- + return new AutoValue_BaseOptions.Builder()
- + .setComputeSettings(ComputeSettings.builder().build())
- + .setNumThreads(DEFAULT_NUM_THREADS);
- + }
-
- - abstract ComputeSettings getComputeSettings();
- + abstract ComputeSettings getComputeSettings();
-
- - abstract int getNumThreads();
- + abstract int getNumThreads();
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java
- index b3fe9def83c69..a8ae65cd1cf3b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java
- @@ -16,76 +16,78 @@ limitations under the License.
- package org.tensorflow.lite.task.core;
-
- import android.util.Log;
- +
- import java.io.Closeable;
-
- /**
- * Base class for Task API, provides shared logic to load/unload native libs to its C++ counterpart.
- */
- public abstract class BaseTaskApi implements Closeable {
- - private static final String TAG = BaseTaskApi.class.getSimpleName();
- -
- - /**
- - * Represents a pointer to the corresponding C++ task_api object. The nativeHandle pointer is
- - * initialized from subclasses and must be released by calling {@link #deinit} after it is no
- - * longer needed.
- - */
- - private final long nativeHandle;
- -
- - /** Indicates whether the {@link #nativeHandle} pointer has been released yet. */
- - private boolean closed;
- -
- - /**
- - * Constructor to initialize the JNI with a pointer from C++.
- - *
- - * @param nativeHandle a pointer referencing memory allocated in C++.
- - */
- - protected BaseTaskApi(long nativeHandle) {
- - if (nativeHandle == TaskJniUtils.INVALID_POINTER) {
- - throw new IllegalArgumentException("Failed to load C++ pointer from JNI");
- + private static final String TAG = BaseTaskApi.class.getSimpleName();
- +
- + /**
- + * Represents a pointer to the corresponding C++ task_api object. The nativeHandle pointer is
- + * initialized from subclasses and must be released by calling {@link #deinit} after it is no
- + * longer needed.
- + */
- + private final long nativeHandle;
- +
- + /** Indicates whether the {@link #nativeHandle} pointer has been released yet. */
- + private boolean closed;
- +
- + /**
- + * Constructor to initialize the JNI with a pointer from C++.
- + *
- + * @param nativeHandle a pointer referencing memory allocated in C++.
- + */
- + protected BaseTaskApi(long nativeHandle) {
- + if (nativeHandle == TaskJniUtils.INVALID_POINTER) {
- + throw new IllegalArgumentException("Failed to load C++ pointer from JNI");
- + }
- + this.nativeHandle = nativeHandle;
- + }
- +
- + public boolean isClosed() {
- + return closed;
- }
- - this.nativeHandle = nativeHandle;
- - }
- -
- - public boolean isClosed() {
- - return closed;
- - }
- -
- - /** Release the memory allocated from C++ and deregister the library from the static holder. */
- - @Override
- - public synchronized void close() {
- - if (closed) {
- - return;
- +
- + /** Release the memory allocated from C++ and deregister the library from the static holder. */
- + @Override
- + public synchronized void close() {
- + if (closed) {
- + return;
- + }
- + deinit(nativeHandle);
- + closed = true;
- }
- - deinit(nativeHandle);
- - closed = true;
- - }
-
- - public long getNativeHandle() {
- - return nativeHandle;
- - }
- + public long getNativeHandle() {
- + return nativeHandle;
- + }
-
- - protected void checkNotClosed() {
- - if (isClosed()) {
- - throw new IllegalStateException("Internal error: The task lib has already been closed.");
- + protected void checkNotClosed() {
- + if (isClosed()) {
- + throw new IllegalStateException(
- + "Internal error: The task lib has already been closed.");
- + }
- }
- - }
- -
- - @Override
- - protected void finalize() throws Throwable {
- - try {
- - if (!closed) {
- - Log.w(TAG, "Closing an already closed native lib");
- - close();
- - }
- - } finally {
- - super.finalize();
- +
- + @Override
- + protected void finalize() throws Throwable {
- + try {
- + if (!closed) {
- + Log.w(TAG, "Closing an already closed native lib");
- + close();
- + }
- + } finally {
- + super.finalize();
- + }
- }
- - }
- -
- - /**
- - * Releases memory pointed by the pointer in the native layer.
- - *
- - * @param nativeHandle pointer to memory allocated
- - */
- - protected abstract void deinit(long nativeHandle);
- +
- + /**
- + * Releases memory pointed by the pointer in the native layer.
- + *
- + * @param nativeHandle pointer to memory allocated
- + */
- + protected abstract void deinit(long nativeHandle);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java
- index 80a9e82ff3802..0c2d04283594d 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java
- @@ -20,38 +20,36 @@ import com.google.auto.value.AutoValue;
- /** Options to configure how to accelerate the model inference using dedicated delegates. */
- @AutoValue
- public abstract class ComputeSettings {
- + /** TFLite accelerator delegate options. */
- + public enum Delegate {
- + NONE(0),
- + NNAPI(1),
- + GPU(2);
-
- - /** TFLite accelerator delegate options. */
- - public enum Delegate {
- - NONE(0),
- - NNAPI(1),
- - GPU(2);
- + private final int value;
-
- - private final int value;
- + Delegate(int value) {
- + this.value = value;
- + }
-
- - Delegate(int value) {
- - this.value = value;
- + public int getValue() {
- + return value;
- + }
- }
-
- - public int getValue() {
- - return value;
- - }
- - }
- -
- - /** Builder for {@link ComputeSettings}. */
- - @AutoValue.Builder
- - public abstract static class Builder {
- -
- - public abstract Builder setDelegate(Delegate delegate);
- + /** Builder for {@link ComputeSettings}. */
- + @AutoValue.Builder
- + public abstract static class Builder {
- + public abstract Builder setDelegate(Delegate delegate);
-
- - public abstract ComputeSettings build();
- - }
- + public abstract ComputeSettings build();
- + }
-
- - public static Builder builder() {
- - return new AutoValue_ComputeSettings.Builder().setDelegate(DEFAULT_DELEGATE);
- - }
- + public static Builder builder() {
- + return new AutoValue_ComputeSettings.Builder().setDelegate(DEFAULT_DELEGATE);
- + }
-
- - public abstract Delegate getDelegate();
- + public abstract Delegate getDelegate();
-
- - private static final Delegate DEFAULT_DELEGATE = Delegate.NONE;
- + private static final Delegate DEFAULT_DELEGATE = Delegate.NONE;
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java
- index 76109f453b01f..9d5b775456c43 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java
- @@ -18,6 +18,7 @@ package org.tensorflow.lite.task.core;
- import android.content.Context;
- import android.content.res.AssetFileDescriptor;
- import android.util.Log;
- +
- import java.io.FileInputStream;
- import java.io.IOException;
- import java.nio.ByteBuffer;
- @@ -26,156 +27,146 @@ import java.nio.channels.FileChannel;
-
- /** JNI utils for Task API. */
- public class TaskJniUtils {
- - public static final long INVALID_POINTER = 0;
- - private static final String TAG = TaskJniUtils.class.getSimpleName();
- - /** Syntax sugar to get nativeHandle from empty param list. */
- - public interface EmptyHandleProvider {
- - long createHandle();
- - }
- -
- - /** Syntax sugar to get nativeHandle from an array of {@link ByteBuffer}s. */
- - public interface MultipleBuffersHandleProvider {
- - long createHandle(ByteBuffer... buffers);
- - }
- -
- - /** Syntax sugar to get nativeHandle from file descriptor and options. */
- - public interface FdAndOptionsHandleProvider<T> {
- - long createHandle(
- - int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset, T options);
- - }
- -
- - /**
- - * Initializes the JNI and returns C++ handle with file descriptor and options for task API.
- - *
- - * @param context the Android app context
- - * @param provider provider to get C++ handle, usually returned from native call
- - * @param libName name of C++ lib to be loaded
- - * @param filePath path of the file to be loaded
- - * @param options options to set up the task API, used by the provider
- - * @return C++ handle as long
- - * @throws IOException If model file fails to load.
- - */
- - public static <T> long createHandleFromFdAndOptions(
- - Context context,
- - final FdAndOptionsHandleProvider<T> provider,
- - String libName,
- - String filePath,
- - final T options)
- - throws IOException {
- - try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(filePath)) {
- - return createHandleFromLibrary(
- - new EmptyHandleProvider() {
- + public static final long INVALID_POINTER = 0;
- + private static final String TAG = TaskJniUtils.class.getSimpleName();
- + /** Syntax sugar to get nativeHandle from empty param list. */
- + public interface EmptyHandleProvider {
- + long createHandle();
- + }
- +
- + /** Syntax sugar to get nativeHandle from an array of {@link ByteBuffer}s. */
- + public interface MultipleBuffersHandleProvider {
- + long createHandle(ByteBuffer... buffers);
- + }
- +
- + /** Syntax sugar to get nativeHandle from file descriptor and options. */
- + public interface FdAndOptionsHandleProvider<T> {
- + long createHandle(int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset,
- + T options);
- + }
- +
- + /**
- + * Initializes the JNI and returns C++ handle with file descriptor and options for task API.
- + *
- + * @param context the Android app context
- + * @param provider provider to get C++ handle, usually returned from native call
- + * @param libName name of C++ lib to be loaded
- + * @param filePath path of the file to be loaded
- + * @param options options to set up the task API, used by the provider
- + * @return C++ handle as long
- + * @throws IOException If model file fails to load.
- + */
- + public static <T> long createHandleFromFdAndOptions(Context context,
- + final FdAndOptionsHandleProvider<T> provider, String libName, String filePath,
- + final T options) throws IOException {
- + try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(filePath)) {
- + return createHandleFromLibrary(new EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + return provider.createHandle(
- + /*fileDescriptor=*/assetFileDescriptor.getParcelFileDescriptor()
- + .getFd(),
- + /*fileDescriptorLength=*/assetFileDescriptor.getLength(),
- + /*fileDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options);
- + }
- + }, libName);
- + }
- + }
- +
- + /**
- + * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes
- + * {@link EmptyHandleProvider#createHandle()}.
- + *
- + * @param provider provider to get C++ handle, usually returned from native call
- + * @return C++ handle as long
- + */
- + public static long createHandleFromLibrary(EmptyHandleProvider provider, String libName) {
- + tryLoadLibrary(libName);
- + try {
- + return provider.createHandle();
- + } catch (RuntimeException e) {
- + String errorMessage = "Error getting native address of native library: " + libName;
- + Log.e(TAG, errorMessage, e);
- + throw new IllegalStateException(errorMessage, e);
- + }
- + }
- +
- + /**
- + * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes
- + * {@link MultipleBuffersHandleProvider#createHandle(ByteBuffer...)}.
- + *
- + * @param context app context
- + * @param provider provider to get C++ pointer, usually returned from native call
- + * @param libName name of C++ lib to load
- + * @param filePaths file paths to load
- + * @return C++ pointer as long
- + * @throws IOException If model file fails to load.
- + */
- + public static long createHandleWithMultipleAssetFilesFromLibrary(Context context,
- + final MultipleBuffersHandleProvider provider, String libName, String... filePaths)
- + throws IOException {
- + final MappedByteBuffer[] buffers = new MappedByteBuffer[filePaths.length];
- + for (int i = 0; i < filePaths.length; i++) {
- + buffers[i] = loadMappedFile(context, filePaths[i]);
- + }
- + return createHandleFromLibrary(new EmptyHandleProvider() {
- @Override
- public long createHandle() {
- - return provider.createHandle(
- - /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
- - /*fileDescriptorLength=*/ assetFileDescriptor.getLength(),
- - /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
- - options);
- + return provider.createHandle(buffers);
- }
- - },
- - libName);
- - }
- - }
- -
- - /**
- - * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes
- - * {@link EmptyHandleProvider#createHandle()}.
- - *
- - * @param provider provider to get C++ handle, usually returned from native call
- - * @return C++ handle as long
- - */
- - public static long createHandleFromLibrary(EmptyHandleProvider provider, String libName) {
- - tryLoadLibrary(libName);
- - try {
- - return provider.createHandle();
- - } catch (RuntimeException e) {
- - String errorMessage = "Error getting native address of native library: " + libName;
- - Log.e(TAG, errorMessage, e);
- - throw new IllegalStateException(errorMessage, e);
- - }
- - }
- -
- - /**
- - * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes
- - * {@link MultipleBuffersHandleProvider#createHandle(ByteBuffer...)}.
- - *
- - * @param context app context
- - * @param provider provider to get C++ pointer, usually returned from native call
- - * @param libName name of C++ lib to load
- - * @param filePaths file paths to load
- - * @return C++ pointer as long
- - * @throws IOException If model file fails to load.
- - */
- - public static long createHandleWithMultipleAssetFilesFromLibrary(
- - Context context,
- - final MultipleBuffersHandleProvider provider,
- - String libName,
- - String... filePaths)
- - throws IOException {
- - final MappedByteBuffer[] buffers = new MappedByteBuffer[filePaths.length];
- - for (int i = 0; i < filePaths.length; i++) {
- - buffers[i] = loadMappedFile(context, filePaths[i]);
- + }, libName);
- }
- - return createHandleFromLibrary(
- - new EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - return provider.createHandle(buffers);
- - }
- - },
- - libName);
- - }
- -
- - /**
- - * Loads a file from the asset folder through memory mapping.
- - *
- - * @param context Application context to access assets.
- - * @param filePath Asset path of the file.
- - * @return the loaded memory mapped file.
- - * @throws IOException If model file fails to load.
- - */
- - public static MappedByteBuffer loadMappedFile(Context context, String filePath)
- - throws IOException {
- - try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
- - FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
- - FileChannel fileChannel = inputStream.getChannel();
- - long startOffset = fileDescriptor.getStartOffset();
- - long declaredLength = fileDescriptor.getDeclaredLength();
- - return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
- +
- + /**
- + * Loads a file from the asset folder through memory mapping.
- + *
- + * @param context Application context to access assets.
- + * @param filePath Asset path of the file.
- + * @return the loaded memory mapped file.
- + * @throws IOException If model file fails to load.
- + */
- + public static MappedByteBuffer loadMappedFile(Context context, String filePath)
- + throws IOException {
- + try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
- + FileInputStream inputStream =
- + new FileInputStream(fileDescriptor.getFileDescriptor())) {
- + FileChannel fileChannel = inputStream.getChannel();
- + long startOffset = fileDescriptor.getStartOffset();
- + long declaredLength = fileDescriptor.getDeclaredLength();
- + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
- + }
- }
- - }
- -
- - /**
- - * Try loading a native library, if it's already loaded return directly.
- - *
- - * @param libName name of the lib
- - */
- - public static void tryLoadLibrary(String libName) {
- - try {
- - System.loadLibrary(libName);
- - } catch (UnsatisfiedLinkError e) {
- - String errorMessage = "Error loading native library: " + libName;
- - Log.e(TAG, errorMessage, e);
- - throw new UnsatisfiedLinkError(errorMessage);
- +
- + /**
- + * Try loading a native library, if it's already loaded return directly.
- + *
- + * @param libName name of the lib
- + */
- + public static void tryLoadLibrary(String libName) {
- + try {
- + System.loadLibrary(libName);
- + } catch (UnsatisfiedLinkError e) {
- + String errorMessage = "Error loading native library: " + libName;
- + Log.e(TAG, errorMessage, e);
- + throw new UnsatisfiedLinkError(errorMessage);
- + }
- }
- - }
-
- - public static long createProtoBaseOptionsHandle(BaseOptions baseOptions) {
- - return createProtoBaseOptionsHandleWithLegacyNumThreads(baseOptions, /*legacyNumThreads =*/ -1);
- - }
- + public static long createProtoBaseOptionsHandle(BaseOptions baseOptions) {
- + return createProtoBaseOptionsHandleWithLegacyNumThreads(
- + baseOptions, /*legacyNumThreads =*/-1);
- + }
-
- - public static long createProtoBaseOptionsHandleWithLegacyNumThreads(
- - BaseOptions baseOptions, int legacyNumThreads) {
- - // NumThreads should be configured through BaseOptions. However, if NumThreads is configured
- - // through the legacy API of the Task Java API (then it will not equal to -1, the default
- - // value), use it to overide the one in baseOptions.
- - return createProtoBaseOptions(
- - baseOptions.getComputeSettings().getDelegate().getValue(),
- - legacyNumThreads == -1 ? baseOptions.getNumThreads() : legacyNumThreads);
- - }
- + public static long createProtoBaseOptionsHandleWithLegacyNumThreads(
- + BaseOptions baseOptions, int legacyNumThreads) {
- + // NumThreads should be configured through BaseOptions. However, if NumThreads is configured
- + // through the legacy API of the Task Java API (then it will not equal to -1, the default
- + // value), use it to overide the one in baseOptions.
- + return createProtoBaseOptions(baseOptions.getComputeSettings().getDelegate().getValue(),
- + legacyNumThreads == -1 ? baseOptions.getNumThreads() : legacyNumThreads);
- + }
-
- - private TaskJniUtils() {}
- + private TaskJniUtils() {}
-
- - private static native long createProtoBaseOptions(int delegate, int numThreads);
- + private static native long createProtoBaseOptions(int delegate, int numThreads);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/annotations/UsedByReflection.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/annotations/UsedByReflection.java
- index bfa1ea750cf1f..fb1dfec82d7b4 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/annotations/UsedByReflection.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/annotations/UsedByReflection.java
- @@ -27,5 +27,5 @@ import java.lang.annotation.Target;
- */
- @Target({ElementType.METHOD, ElementType.FIELD, ElementType.TYPE, ElementType.CONSTRUCTOR})
- public @interface UsedByReflection {
- - String value();
- + String value();
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java
- index 287ba444c386b..b1784d02f2362 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java
- @@ -16,6 +16,7 @@ limitations under the License.
- package org.tensorflow.lite.task.core.vision;
-
- import android.graphics.Rect;
- +
- import com.google.auto.value.AutoValue;
-
- /**
- @@ -45,74 +46,74 @@ import com.google.auto.value.AutoValue;
- */
- @AutoValue
- public abstract class ImageProcessingOptions {
- -
- - /**
- - * Orientation type that follows EXIF specification.
- - *
- - * <p>The name of each enum value defines the position of the 0th row and the 0th column of the
- - * image content. See the <a href="http://jpegclub.org/exif_orientation.html">EXIF orientation
- - * documentation</a> for details.
- - */
- - public enum Orientation {
- - TOP_LEFT(0),
- - TOP_RIGHT(1),
- - BOTTOM_RIGHT(2),
- - BOTTOM_LEFT(3),
- - LEFT_TOP(4),
- - RIGHT_TOP(5),
- - RIGHT_BOTTOM(6),
- - LEFT_BOTTOM(7);
- -
- - private final int value;
- -
- - Orientation(int value) {
- - this.value = value;
- - }
- -
- - public int getValue() {
- - return value;
- - }
- - };
- -
- - private static final Rect defaultRoi = new Rect();
- - private static final Orientation DEFAULT_ORIENTATION = Orientation.TOP_LEFT;
- -
- - public abstract Rect getRoi();
- -
- - public abstract Orientation getOrientation();
- -
- - public static Builder builder() {
- - return new AutoValue_ImageProcessingOptions.Builder()
- - .setRoi(defaultRoi)
- - .setOrientation(DEFAULT_ORIENTATION);
- - }
- -
- - /** Builder for {@link ImageProcessingOptions}. */
- - @AutoValue.Builder
- - public abstract static class Builder {
- -
- /**
- - * Sets the region of interest (ROI) of the image. Defaults to the entire image.
- + * Orientation type that follows EXIF specification.
- *
- - * <p>Cropping according to this region of interest is prepended to the pre-processing
- - * operations.
- + * <p>The name of each enum value defines the position of the 0th row and the 0th column of the
- + * image content. See the <a href="http://jpegclub.org/exif_orientation.html">EXIF orientation
- + * documentation</a> for details.
- */
- - public abstract Builder setRoi(Rect roi);
- + public enum Orientation {
- + TOP_LEFT(0),
- + TOP_RIGHT(1),
- + BOTTOM_RIGHT(2),
- + BOTTOM_LEFT(3),
- + LEFT_TOP(4),
- + RIGHT_TOP(5),
- + RIGHT_BOTTOM(6),
- + LEFT_BOTTOM(7);
- +
- + private final int value;
- +
- + Orientation(int value) {
- + this.value = value;
- + }
- +
- + public int getValue() {
- + return value;
- + }
- + }
- + ;
-
- - /**
- - * Sets the orientation of the image. Defaults to {@link Orientation#TOP_LEFT}.
- - *
- - * <p>Rotation will be applied accordingly so that inference is performed on an "upright" image.
- - */
- - public abstract Builder setOrientation(Orientation orientation);
- + private static final Rect defaultRoi = new Rect();
- + private static final Orientation DEFAULT_ORIENTATION = Orientation.TOP_LEFT;
-
- - abstract Rect getRoi();
- + public abstract Rect getRoi();
-
- - abstract ImageProcessingOptions autoBuild();
- + public abstract Orientation getOrientation();
- +
- + public static Builder builder() {
- + return new AutoValue_ImageProcessingOptions.Builder()
- + .setRoi(defaultRoi)
- + .setOrientation(DEFAULT_ORIENTATION);
- + }
-
- - public ImageProcessingOptions build() {
- - setRoi(new Rect(getRoi())); // Make a defensive copy, since Rect is mutable.
- - return autoBuild();
- + /** Builder for {@link ImageProcessingOptions}. */
- + @AutoValue.Builder
- + public abstract static class Builder {
- + /**
- + * Sets the region of interest (ROI) of the image. Defaults to the entire image.
- + *
- + * <p>Cropping according to this region of interest is prepended to the pre-processing
- + * operations.
- + */
- + public abstract Builder setRoi(Rect roi);
- +
- + /**
- + * Sets the orientation of the image. Defaults to {@link Orientation#TOP_LEFT}.
- + *
- + * <p>Rotation will be applied accordingly so that inference is performed on an "upright"
- + * image.
- + */
- + public abstract Builder setOrientation(Orientation orientation);
- +
- + abstract Rect getRoi();
- +
- + abstract ImageProcessingOptions autoBuild();
- +
- + public ImageProcessingOptions build() {
- + setRoi(new Rect(getRoi())); // Make a defensive copy, since Rect is mutable.
- + return autoBuild();
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/NearestNeighbor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/NearestNeighbor.java
- index f5cc5af615117..a39247f1239c8 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/NearestNeighbor.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/NearestNeighbor.java
- @@ -16,37 +16,38 @@ limitations under the License.
- package org.tensorflow.lite.task.processor;
-
- import com.google.auto.value.AutoValue;
- +
- +import org.tensorflow.lite.task.core.annotations.UsedByReflection;
- +
- import java.nio.ByteBuffer;
- import java.nio.ByteOrder;
- -import org.tensorflow.lite.task.core.annotations.UsedByReflection;
-
- /** Represents the search result of a Searcher model. */
- @AutoValue
- @UsedByReflection("searcher_jni.cc")
- public abstract class NearestNeighbor {
- -
- - @UsedByReflection("searcher_jni.cc")
- - static NearestNeighbor create(byte[] metadataArray, float distance) {
- - // Convert byte[] metadataArray to ByteBuffer which handles endianess better.
- - //
- - // Ideally, the API should accept a ByteBuffer instead of a byte[]. However, converting byte[]
- - // to ByteBuffer in JNI will lead to unnecessarily complex code which involves 6 more reflection
- - // calls. We can make this method package private, because users in general shouldn't need to
- - // create NearestNeighbor instances, but only consume the objects return from Task Library. This
- - // API will be used mostly for internal purpose.
- - ByteBuffer metadata = ByteBuffer.wrap(metadataArray);
- - metadata.order(ByteOrder.nativeOrder());
- - return new AutoValue_NearestNeighbor(metadata, distance);
- - }
- -
- - /**
- - * Gets the user-defined metadata about the result. This could be a label, a unique ID, a
- - * serialized proto of some sort, etc.
- - *
- - * <p><b>Do not mutate</b> the returned metadata.
- - */
- - public abstract ByteBuffer getMetadata();
- -
- - /** Gets the distance score indicating how confident the result is. Lower is better. */
- - public abstract float getDistance();
- + @UsedByReflection("searcher_jni.cc")
- + static NearestNeighbor create(byte[] metadataArray, float distance) {
- + // Convert byte[] metadataArray to ByteBuffer which handles endianess better.
- + //
- + // Ideally, the API should accept a ByteBuffer instead of a byte[]. However, converting
- + // byte[] to ByteBuffer in JNI will lead to unnecessarily complex code which involves 6 more
- + // reflection calls. We can make this method package private, because users in general
- + // shouldn't need to create NearestNeighbor instances, but only consume the objects return
- + // from Task Library. This API will be used mostly for internal purpose.
- + ByteBuffer metadata = ByteBuffer.wrap(metadataArray);
- + metadata.order(ByteOrder.nativeOrder());
- + return new AutoValue_NearestNeighbor(metadata, distance);
- + }
- +
- + /**
- + * Gets the user-defined metadata about the result. This could be a label, a unique ID, a
- + * serialized proto of some sort, etc.
- + *
- + * <p><b>Do not mutate</b> the returned metadata.
- + */
- + public abstract ByteBuffer getMetadata();
- +
- + /** Gets the distance score indicating how confident the result is. Lower is better. */
- + public abstract float getDistance();
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/SearcherOptions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/SearcherOptions.java
- index fa601edf92b30..86f5fdde0187c 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/SearcherOptions.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/SearcherOptions.java
- @@ -16,66 +16,68 @@ limitations under the License.
- package org.tensorflow.lite.task.processor;
-
- import androidx.annotation.Nullable;
- +
- import com.google.auto.value.AutoValue;
- +
- import java.io.File;
-
- /** Options to configure Searcher API. */
- @AutoValue
- public abstract class SearcherOptions {
- - private static final boolean DEFAULT_L2_NORMALIZE = false;
- - private static final boolean DEFAULT_QUANTIZE = false;
- - private static final int DEFAULT_MAX_RESULTS = 5;
- -
- - public abstract boolean getL2Normalize();
- -
- - public abstract boolean getQuantize();
- -
- - @Nullable
- - public abstract File getIndexFile();
- -
- - public abstract int getMaxResults();
- -
- - public static Builder builder() {
- - return new AutoValue_SearcherOptions.Builder()
- - .setL2Normalize(DEFAULT_L2_NORMALIZE)
- - .setQuantize(DEFAULT_QUANTIZE)
- - .setIndexFile(null)
- - .setMaxResults(DEFAULT_MAX_RESULTS);
- - }
- -
- - /** Builder for {@link SearcherOptions}. */
- - @AutoValue.Builder
- - public abstract static class Builder {
- - /**
- - * Sets whether to normalize the embedding feature vector with L2 norm. Defaults to false.
- - *
- - * <p>Use this option only if the model does not already contain a native L2_NORMALIZATION
- - * TFLite Op. In most cases, this is already the case and L2 norm is thus achieved through
- - * TFLite inference.
- - */
- - public abstract Builder setL2Normalize(boolean l2Normalize);
- -
- - /**
- - * Sets whether the embedding should be quantized to bytes via scalar quantization. Defaults to
- - * false.
- - *
- - * <p>Embeddings are implicitly assumed to be unit-norm and therefore any dimension is
- - * guaranteed to have a value in {@code [-1.0, 1.0]}. Use the l2_normalize option if this is not
- - * the case.
- - */
- - public abstract Builder setQuantize(boolean quantize);
- -
- - /**
- - * Sets the index file to search into.
- - *
- - * <p>Required if the model does not come with an index file inside. Otherwise, it can be ignore
- - * by setting to {@code null}.
- - */
- - public abstract Builder setIndexFile(@Nullable File indexFile);
- -
- - /** Sets the maximum number of nearest neighbor results to return. Defaults to {@code 5} */
- - public abstract Builder setMaxResults(int maxResults);
- -
- - public abstract SearcherOptions build();
- - }
- + private static final boolean DEFAULT_L2_NORMALIZE = false;
- + private static final boolean DEFAULT_QUANTIZE = false;
- + private static final int DEFAULT_MAX_RESULTS = 5;
- +
- + public abstract boolean getL2Normalize();
- +
- + public abstract boolean getQuantize();
- +
- + @Nullable
- + public abstract File getIndexFile();
- +
- + public abstract int getMaxResults();
- +
- + public static Builder builder() {
- + return new AutoValue_SearcherOptions.Builder()
- + .setL2Normalize(DEFAULT_L2_NORMALIZE)
- + .setQuantize(DEFAULT_QUANTIZE)
- + .setIndexFile(null)
- + .setMaxResults(DEFAULT_MAX_RESULTS);
- + }
- +
- + /** Builder for {@link SearcherOptions}. */
- + @AutoValue.Builder
- + public abstract static class Builder {
- + /**
- + * Sets whether to normalize the embedding feature vector with L2 norm. Defaults to false.
- + *
- + * <p>Use this option only if the model does not already contain a native L2_NORMALIZATION
- + * TFLite Op. In most cases, this is already the case and L2 norm is thus achieved through
- + * TFLite inference.
- + */
- + public abstract Builder setL2Normalize(boolean l2Normalize);
- +
- + /**
- + * Sets whether the embedding should be quantized to bytes via scalar quantization. Defaults
- + * to false.
- + *
- + * <p>Embeddings are implicitly assumed to be unit-norm and therefore any dimension is
- + * guaranteed to have a value in {@code [-1.0, 1.0]}. Use the l2_normalize option if this is
- + * not the case.
- + */
- + public abstract Builder setQuantize(boolean quantize);
- +
- + /**
- + * Sets the index file to search into.
- + *
- + * <p>Required if the model does not come with an index file inside. Otherwise, it can be
- + * ignore by setting to {@code null}.
- + */
- + public abstract Builder setIndexFile(@Nullable File indexFile);
- +
- + /** Sets the maximum number of nearest neighbor results to return. Defaults to {@code 5} */
- + public abstract Builder setMaxResults(int maxResults);
- +
- + public abstract SearcherOptions build();
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
- index 55743055ff408..070b945e72b90 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
- @@ -17,12 +17,9 @@ package org.tensorflow.lite.task.text.nlclassifier;
-
- import android.content.Context;
- import android.os.ParcelFileDescriptor;
- +
- import com.google.auto.value.AutoValue;
- -import java.io.File;
- -import java.io.IOException;
- -import java.nio.ByteBuffer;
- -import java.nio.MappedByteBuffer;
- -import java.util.List;
- +
- import org.tensorflow.lite.support.label.Category;
- import org.tensorflow.lite.task.core.BaseOptions;
- import org.tensorflow.lite.task.core.BaseTaskApi;
- @@ -30,6 +27,12 @@ import org.tensorflow.lite.task.core.TaskJniUtils;
- import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
- import org.tensorflow.lite.task.core.annotations.UsedByReflection;
-
- +import java.io.File;
- +import java.io.IOException;
- +import java.nio.ByteBuffer;
- +import java.nio.MappedByteBuffer;
- +import java.util.List;
- +
- /**
- * Classifier API for NLClassification tasks with Bert models, categorizes string into different
- * classes. The API expects a Bert based TFLite model with metadata populated.
- @@ -45,209 +48,199 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection;
- * </ul>
- */
- public class BertNLClassifier extends BaseTaskApi {
- + private static final String BERT_NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni";
- +
- + /** Options to configure BertNLClassifier. */
- + @AutoValue
- + @UsedByReflection("bert_nl_classifier_jni.cc")
- + public abstract static class BertNLClassifierOptions {
- + static final int DEFAULT_MAX_SEQ_LEN = 128;
- +
- + abstract int getMaxSeqLen();
- +
- + abstract BaseOptions getBaseOptions();
- +
- + public static Builder builder() {
- + return new AutoValue_BertNLClassifier_BertNLClassifierOptions.Builder()
- + .setMaxSeqLen(DEFAULT_MAX_SEQ_LEN)
- + .setBaseOptions(BaseOptions.builder().build());
- + }
- +
- + /** Builder for {@link BertNLClassifierOptions}. */
- + @AutoValue.Builder
- + public abstract static class Builder {
- + /** Sets the general options to configure Task APIs, such as accelerators. */
- + public abstract Builder setBaseOptions(BaseOptions baseOptions);
- +
- + /**
- + * Set the maximum sequence length.
- + *
- + * @deprecated maximum sequence length is now read from the model (i.e. input tensor
- + * size)
- + * automatically
- + */
- + @Deprecated
- + public abstract Builder setMaxSeqLen(int value);
- +
- + public abstract BertNLClassifierOptions build();
- + }
- + }
-
- - private static final String BERT_NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni";
- -
- - /** Options to configure BertNLClassifier. */
- - @AutoValue
- - @UsedByReflection("bert_nl_classifier_jni.cc")
- - public abstract static class BertNLClassifierOptions {
- - static final int DEFAULT_MAX_SEQ_LEN = 128;
- -
- - abstract int getMaxSeqLen();
- + /**
- + * Creates {@link BertNLClassifier} from a model file with metadata and default {@link
- + * BertNLClassifierOptions}.
- + *
- + * @param context Android context
- + * @param modelPath Path to the classification model
- + * @return a {@link BertNLClassifier} instance
- + * @throws IOException If model file fails to load
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static BertNLClassifier createFromFile(final Context context, final String modelPath)
- + throws IOException {
- + return createFromBuffer(TaskJniUtils.loadMappedFile(context, modelPath));
- + }
-
- - abstract BaseOptions getBaseOptions();
- + /**
- + * Creates {@link BertNLClassifier} from a {@link File} object with metadata and default {@link
- + * BertNLClassifierOptions}.
- + *
- + * @param modelFile The classification model {@link File} instance
- + * @return a {@link BertNLClassifier} instance
- + * @throws IOException If model file fails to load
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static BertNLClassifier createFromFile(File modelFile) throws IOException {
- + return createFromFileAndOptions(modelFile, BertNLClassifierOptions.builder().build());
- + }
-
- - public static Builder builder() {
- - return new AutoValue_BertNLClassifier_BertNLClassifierOptions.Builder()
- - .setMaxSeqLen(DEFAULT_MAX_SEQ_LEN)
- - .setBaseOptions(BaseOptions.builder().build());
- + /**
- + * Creates {@link BertNLClassifier} from a model file with metadata and {@link
- + * BertNLClassifierOptions}.
- + *
- + * @param context Android context.
- + * @param modelPath Path to the classification model
- + * @param options to configure the classifier
- + * @return a {@link BertNLClassifier} instance
- + * @throws IOException If model file fails to load
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static BertNLClassifier createFromFileAndOptions(final Context context,
- + final String modelPath, BertNLClassifierOptions options) throws IOException {
- + return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options);
- }
-
- - /** Builder for {@link BertNLClassifierOptions}. */
- - @AutoValue.Builder
- - public abstract static class Builder {
- + /**
- + * Creates {@link BertNLClassifier} from a {@link File} object with metadata and {@link
- + * BertNLClassifierOptions}.
- + *
- + * @param modelFile The classification model {@link File} instance
- + * @param options to configure the classifier
- + * @return a {@link BertNLClassifier} instance
- + * @throws IOException If model file fails to load
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static BertNLClassifier createFromFileAndOptions(
- + File modelFile, final BertNLClassifierOptions options) throws IOException {
- + try (ParcelFileDescriptor descriptor =
- + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- + return new BertNLClassifier(
- + TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + return initJniWithFileDescriptor(descriptor.getFd(), options,
- + TaskJniUtils.createProtoBaseOptionsHandle(
- + options.getBaseOptions()));
- + }
- + }, BERT_NL_CLASSIFIER_NATIVE_LIBNAME));
- + }
- + }
-
- - /** Sets the general options to configure Task APIs, such as accelerators. */
- - public abstract Builder setBaseOptions(BaseOptions baseOptions);
- + /**
- + * Creates {@link BertNLClassifier} with a model buffer and default {@link
- + * BertNLClassifierOptions}.
- + *
- + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model
- + * @return a {@link BertNLClassifier} instance
- + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- + * {@link MappedByteBuffer}
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static BertNLClassifier createFromBuffer(final ByteBuffer modelBuffer) {
- + return createFromBufferAndOptions(modelBuffer, BertNLClassifierOptions.builder().build());
- + }
-
- - /**
- - * Set the maximum sequence length.
- - *
- - * @deprecated maximum sequence length is now read from the model (i.e. input tensor size)
- - * automatically
- - */
- - @Deprecated
- - public abstract Builder setMaxSeqLen(int value);
- + /**
- + * Creates {@link BertNLClassifier} with a model buffer and {@link BertNLClassifierOptions}.
- + *
- + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model
- + * @param options to configure the classifier
- + * @return a {@link BertNLClassifier} instance
- + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- + * {@link MappedByteBuffer}
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static BertNLClassifier createFromBufferAndOptions(
- + final ByteBuffer modelBuffer, final BertNLClassifierOptions options) {
- + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- + throw new IllegalArgumentException(
- + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
- + }
- + return new BertNLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + return initJniWithByteBuffer(modelBuffer, options,
- + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
- + }
- + }, BERT_NL_CLASSIFIER_NATIVE_LIBNAME));
- + }
-
- - public abstract BertNLClassifierOptions build();
- + /**
- + * Performs classification on a string input, returns classified {@link Category}s.
- + *
- + * @param text input text to the model.
- + * @return A list of Category results.
- + */
- + public List<Category> classify(String text) {
- + return classifyNative(getNativeHandle(), text);
- }
- - }
- -
- - /**
- - * Creates {@link BertNLClassifier} from a model file with metadata and default {@link
- - * BertNLClassifierOptions}.
- - *
- - * @param context Android context
- - * @param modelPath Path to the classification model
- - * @return a {@link BertNLClassifier} instance
- - * @throws IOException If model file fails to load
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static BertNLClassifier createFromFile(final Context context, final String modelPath)
- - throws IOException {
- - return createFromBuffer(TaskJniUtils.loadMappedFile(context, modelPath));
- - }
- -
- - /**
- - * Creates {@link BertNLClassifier} from a {@link File} object with metadata and default {@link
- - * BertNLClassifierOptions}.
- - *
- - * @param modelFile The classification model {@link File} instance
- - * @return a {@link BertNLClassifier} instance
- - * @throws IOException If model file fails to load
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static BertNLClassifier createFromFile(File modelFile) throws IOException {
- - return createFromFileAndOptions(modelFile, BertNLClassifierOptions.builder().build());
- - }
- -
- - /**
- - * Creates {@link BertNLClassifier} from a model file with metadata and {@link
- - * BertNLClassifierOptions}.
- - *
- - * @param context Android context.
- - * @param modelPath Path to the classification model
- - * @param options to configure the classifier
- - * @return a {@link BertNLClassifier} instance
- - * @throws IOException If model file fails to load
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static BertNLClassifier createFromFileAndOptions(
- - final Context context, final String modelPath, BertNLClassifierOptions options)
- - throws IOException {
- - return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options);
- - }
- -
- - /**
- - * Creates {@link BertNLClassifier} from a {@link File} object with metadata and {@link
- - * BertNLClassifierOptions}.
- - *
- - * @param modelFile The classification model {@link File} instance
- - * @param options to configure the classifier
- - * @return a {@link BertNLClassifier} instance
- - * @throws IOException If model file fails to load
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static BertNLClassifier createFromFileAndOptions(
- - File modelFile, final BertNLClassifierOptions options) throws IOException {
- - try (ParcelFileDescriptor descriptor =
- - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- - return new BertNLClassifier(
- - TaskJniUtils.createHandleFromLibrary(
- - new EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - return initJniWithFileDescriptor(
- - descriptor.getFd(),
- - options,
- - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
- - }
- - },
- - BERT_NL_CLASSIFIER_NATIVE_LIBNAME));
- +
- + /**
- + * Constructor to initialize the JNI with a pointer from C++.
- + *
- + * @param nativeHandle a pointer referencing memory allocated in C++.
- + */
- + private BertNLClassifier(long nativeHandle) {
- + super(nativeHandle);
- }
- - }
- -
- - /**
- - * Creates {@link BertNLClassifier} with a model buffer and default {@link
- - * BertNLClassifierOptions}.
- - *
- - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model
- - * @return a {@link BertNLClassifier} instance
- - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- - * {@link MappedByteBuffer}
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static BertNLClassifier createFromBuffer(final ByteBuffer modelBuffer) {
- - return createFromBufferAndOptions(modelBuffer, BertNLClassifierOptions.builder().build());
- - }
- -
- - /**
- - * Creates {@link BertNLClassifier} with a model buffer and {@link BertNLClassifierOptions}.
- - *
- - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model
- - * @param options to configure the classifier
- - * @return a {@link BertNLClassifier} instance
- - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- - * {@link MappedByteBuffer}
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static BertNLClassifier createFromBufferAndOptions(
- - final ByteBuffer modelBuffer, final BertNLClassifierOptions options) {
- - if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- - throw new IllegalArgumentException(
- - "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
- +
- + private static native long initJniWithByteBuffer(
- + ByteBuffer modelBuffer, BertNLClassifierOptions options, long baseOptionsHandle);
- +
- + private static native long initJniWithFileDescriptor(
- + int fd, BertNLClassifierOptions options, long baseOptionsHandle);
- +
- + private static native List<Category> classifyNative(long nativeHandle, String text);
- +
- + @Override
- + protected void deinit(long nativeHandle) {
- + deinitJni(nativeHandle);
- }
- - return new BertNLClassifier(
- - TaskJniUtils.createHandleFromLibrary(
- - new EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - return initJniWithByteBuffer(
- - modelBuffer,
- - options,
- - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
- - }
- - },
- - BERT_NL_CLASSIFIER_NATIVE_LIBNAME));
- - }
- -
- - /**
- - * Performs classification on a string input, returns classified {@link Category}s.
- - *
- - * @param text input text to the model.
- - * @return A list of Category results.
- - */
- - public List<Category> classify(String text) {
- - return classifyNative(getNativeHandle(), text);
- - }
- -
- - /**
- - * Constructor to initialize the JNI with a pointer from C++.
- - *
- - * @param nativeHandle a pointer referencing memory allocated in C++.
- - */
- - private BertNLClassifier(long nativeHandle) {
- - super(nativeHandle);
- - }
- -
- - private static native long initJniWithByteBuffer(
- - ByteBuffer modelBuffer, BertNLClassifierOptions options, long baseOptionsHandle);
- -
- - private static native long initJniWithFileDescriptor(
- - int fd, BertNLClassifierOptions options, long baseOptionsHandle);
- -
- - private static native List<Category> classifyNative(long nativeHandle, String text);
- -
- - @Override
- - protected void deinit(long nativeHandle) {
- - deinitJni(nativeHandle);
- - }
- -
- - /**
- - * Native implementation to release memory pointed by the pointer.
- - *
- - * @param nativeHandle pointer to memory allocated
- - */
- - private native void deinitJni(long nativeHandle);
- +
- + /**
- + * Native implementation to release memory pointed by the pointer.
- + *
- + * @param nativeHandle pointer to memory allocated
- + */
- + private native void deinitJni(long nativeHandle);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java
- index 19dcffca5e697..5c3eb2c9e3768 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java
- @@ -17,13 +17,11 @@ package org.tensorflow.lite.task.text.nlclassifier;
-
- import android.content.Context;
- import android.os.ParcelFileDescriptor;
- +
- import androidx.annotation.Nullable;
- +
- import com.google.auto.value.AutoValue;
- -import java.io.File;
- -import java.io.IOException;
- -import java.nio.ByteBuffer;
- -import java.nio.MappedByteBuffer;
- -import java.util.List;
- +
- import org.tensorflow.lite.support.label.Category;
- import org.tensorflow.lite.task.core.BaseOptions;
- import org.tensorflow.lite.task.core.BaseTaskApi;
- @@ -31,6 +29,12 @@ import org.tensorflow.lite.task.core.TaskJniUtils;
- import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
- import org.tensorflow.lite.task.core.annotations.UsedByReflection;
-
- +import java.io.File;
- +import java.io.IOException;
- +import java.nio.ByteBuffer;
- +import java.nio.MappedByteBuffer;
- +import java.util.List;
- +
- /**
- * Classifier API for natural language classification tasks, categorizes string into different
- * classes.
- @@ -67,294 +71,296 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection;
- * configurable for different TFLite models.
- */
- public class NLClassifier extends BaseTaskApi {
- -
- - /** Options to identify input and output tensors of the model. */
- - @AutoValue
- - @UsedByReflection("nl_classifier_jni.cc")
- - public abstract static class NLClassifierOptions {
- - private static final int DEFAULT_INPUT_TENSOR_INDEX = 0;
- - private static final int DEFAULT_OUTPUT_SCORE_TENSOR_INDEX = 0;
- - // By default there is no output label tensor. The label file can be attached
- - // to the output score tensor metadata.
- - private static final int DEFAULT_OUTPUT_LABEL_TENSOR_INDEX = -1;
- - private static final String DEFAULT_INPUT_TENSOR_NAME = "INPUT";
- - private static final String DEFAULT_OUTPUT_SCORE_TENSOR_NAME = "OUTPUT_SCORE";
- - private static final String DEFAULT_OUTPUT_LABEL_TENSOR_NAME = "OUTPUT_LABEL";
- -
- - @UsedByReflection("nl_classifier_jni.cc")
- - abstract int getInputTensorIndex();
- -
- - @UsedByReflection("nl_classifier_jni.cc")
- - abstract int getOutputScoreTensorIndex();
- -
- + /** Options to identify input and output tensors of the model. */
- + @AutoValue
- @UsedByReflection("nl_classifier_jni.cc")
- - abstract int getOutputLabelTensorIndex();
- -
- - @UsedByReflection("nl_classifier_jni.cc")
- - abstract String getInputTensorName();
- + public abstract static class NLClassifierOptions {
- + private static final int DEFAULT_INPUT_TENSOR_INDEX = 0;
- + private static final int DEFAULT_OUTPUT_SCORE_TENSOR_INDEX = 0;
- + // By default there is no output label tensor. The label file can be attached
- + // to the output score tensor metadata.
- + private static final int DEFAULT_OUTPUT_LABEL_TENSOR_INDEX = -1;
- + private static final String DEFAULT_INPUT_TENSOR_NAME = "INPUT";
- + private static final String DEFAULT_OUTPUT_SCORE_TENSOR_NAME = "OUTPUT_SCORE";
- + private static final String DEFAULT_OUTPUT_LABEL_TENSOR_NAME = "OUTPUT_LABEL";
- +
- + @UsedByReflection("nl_classifier_jni.cc")
- + abstract int getInputTensorIndex();
- +
- + @UsedByReflection("nl_classifier_jni.cc")
- + abstract int getOutputScoreTensorIndex();
- +
- + @UsedByReflection("nl_classifier_jni.cc")
- + abstract int getOutputLabelTensorIndex();
- +
- + @UsedByReflection("nl_classifier_jni.cc")
- + abstract String getInputTensorName();
- +
- + @UsedByReflection("nl_classifier_jni.cc")
- + abstract String getOutputScoreTensorName();
- +
- + @UsedByReflection("nl_classifier_jni.cc")
- + abstract String getOutputLabelTensorName();
- +
- + @Nullable
- + abstract BaseOptions getBaseOptions();
- +
- + public static Builder builder() {
- + return new AutoValue_NLClassifier_NLClassifierOptions.Builder()
- + .setInputTensorIndex(DEFAULT_INPUT_TENSOR_INDEX)
- + .setOutputScoreTensorIndex(DEFAULT_OUTPUT_SCORE_TENSOR_INDEX)
- + .setOutputLabelTensorIndex(DEFAULT_OUTPUT_LABEL_TENSOR_INDEX)
- + .setInputTensorName(DEFAULT_INPUT_TENSOR_NAME)
- + .setOutputScoreTensorName(DEFAULT_OUTPUT_SCORE_TENSOR_NAME)
- + .setOutputLabelTensorName(DEFAULT_OUTPUT_LABEL_TENSOR_NAME);
- + }
- +
- + /** Builder for {@link NLClassifierOptions}. */
- + @AutoValue.Builder
- + public abstract static class Builder {
- + /** Sets the general options to configure Task APIs, such as accelerators. */
- + public abstract Builder setBaseOptions(@Nullable BaseOptions baseOptions);
- +
- + /**
- + * Configure the input/output tensors for NLClassifier:
- + *
- + * <p>- No special configuration is needed if the model has only one input tensor and
- + * one output tensor.
- + *
- + * <p>- When the model has multiple input or output tensors, use the following
- + * configurations to specifiy the desired tensors: <br>
- + * -- tensor names: {@code inputTensorName}, {@code outputScoreTensorName}, {@code
- + * outputLabelTensorName}<br>
- + * -- tensor indices: {@code inputTensorIndex}, {@code outputScoreTensorIndex}, {@code
- + * outputLabelTensorIndex} <br>
- + * Tensor names has higher priorities than tensor indices in locating the tensors. It
- + * means the tensors will be first located according to tensor names. If not found, then
- + * the tensors will be located according to tensor indices.
- + *
- + * <p>- Failing to match the input text tensor or output score tensor with neither
- + * tensor names nor tensor indices will trigger a runtime error. However, failing to
- + * locate the output label tensor will not trigger an error because the label tensor is
- + * optional.
- + */
- +
- + /**
- + * Set the name of the input text tensor, if the model has multiple inputs. Only the
- + * input tensor specified will be used for inference; other input tensors will be
- + * ignored. Dafualt to {@code "INPUT"}.
- + *
- + * <p>See the section, Configure the input/output tensors for NLClassifier, for more
- + * details.
- + */
- + public abstract Builder setInputTensorName(String inputTensorName);
- +
- + /**
- + * Set the name of the output score tensor, if the model has multiple outputs. Dafualt
- + * to
- + * {@code "OUTPUT_SCORE"}.
- + *
- + * <p>See the section, Configure the input/output tensors for NLClassifier, for more
- + * details.
- + */
- + public abstract Builder setOutputScoreTensorName(String outputScoreTensorName);
- +
- + /**
- + * Set the name of the output label tensor, if the model has multiple outputs. Dafualt
- + * to
- + * {@code "OUTPUT_LABEL"}.
- + *
- + * <p>See the section, Configure the input/output tensors for NLClassifier, for more
- + * details.
- + *
- + * <p>By default, label file should be packed with the output score tensor through Model
- + * Metadata. See the <a
- + * href="https://www.tensorflow.org/lite/convert/metadata_writer_tutorial#natural_language_classifiers">MetadataWriter
- + * for NLClassifier</a>. NLClassifier reads and parses labels from the label file
- + * automatically. However, some models may output a specific label tensor instead. In
- + * this case, NLClassifier reads labels from the output label tensor.
- + */
- + public abstract Builder setOutputLabelTensorName(String outputLabelTensorName);
- +
- + /**
- + * Set the index of the input text tensor among all input tensors, if the model has
- + * multiple inputs. Only the input tensor specified will be used for inference; other
- + * input tensors will be ignored. Dafualt to 0.
- + *
- + * <p>See the section, Configure the input/output tensors for NLClassifier, for more
- + * details.
- + */
- + public abstract Builder setInputTensorIndex(int inputTensorIndex);
- +
- + /**
- + * Set the index of the output score tensor among all output tensors, if the model has
- + * multiple outputs. Dafualt to 0.
- + *
- + * <p>See the section, Configure the input/output tensors for NLClassifier, for more
- + * details.
- + */
- + public abstract Builder setOutputScoreTensorIndex(int outputScoreTensorIndex);
- +
- + /**
- + * Set the index of the optional output label tensor among all output tensors, if the
- + * model has multiple outputs.
- + *
- + * <p>See the document above {@code outputLabelTensorName} for more information about
- + * what the output label tensor is.
- + *
- + * <p>See the section, Configure the input/output tensors for NLClassifier, for more
- + * details.
- + *
- + * <p>{@code outputLabelTensorIndex} dafualts to -1, meaning to disable the output label
- + * tensor.
- + */
- + public abstract Builder setOutputLabelTensorIndex(int outputLabelTensorIndex);
- +
- + public abstract NLClassifierOptions build();
- + }
- + }
-
- - @UsedByReflection("nl_classifier_jni.cc")
- - abstract String getOutputScoreTensorName();
- + private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni";
- +
- + /**
- + * Creates {@link NLClassifier} from default {@link NLClassifierOptions}.
- + *
- + * @param context Android context
- + * @param modelPath path to the classification model relative to asset dir
- + * @return an {@link NLClassifier} instance
- + * @throws IOException if model file fails to load
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static NLClassifier createFromFile(Context context, String modelPath)
- + throws IOException {
- + return createFromFileAndOptions(context, modelPath, NLClassifierOptions.builder().build());
- + }
-
- - @UsedByReflection("nl_classifier_jni.cc")
- - abstract String getOutputLabelTensorName();
- -
- - @Nullable
- - abstract BaseOptions getBaseOptions();
- -
- - public static Builder builder() {
- - return new AutoValue_NLClassifier_NLClassifierOptions.Builder()
- - .setInputTensorIndex(DEFAULT_INPUT_TENSOR_INDEX)
- - .setOutputScoreTensorIndex(DEFAULT_OUTPUT_SCORE_TENSOR_INDEX)
- - .setOutputLabelTensorIndex(DEFAULT_OUTPUT_LABEL_TENSOR_INDEX)
- - .setInputTensorName(DEFAULT_INPUT_TENSOR_NAME)
- - .setOutputScoreTensorName(DEFAULT_OUTPUT_SCORE_TENSOR_NAME)
- - .setOutputLabelTensorName(DEFAULT_OUTPUT_LABEL_TENSOR_NAME);
- + /**
- + * Creates {@link NLClassifier} from default {@link NLClassifierOptions}.
- + *
- + * @param modelFile the classification model {@link File} instance
- + * @return an {@link NLClassifier} instance
- + * @throws IOException if model file fails to load
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static NLClassifier createFromFile(File modelFile) throws IOException {
- + return createFromFileAndOptions(modelFile, NLClassifierOptions.builder().build());
- }
-
- - /** Builder for {@link NLClassifierOptions}. */
- - @AutoValue.Builder
- - public abstract static class Builder {
- - /** Sets the general options to configure Task APIs, such as accelerators. */
- - public abstract Builder setBaseOptions(@Nullable BaseOptions baseOptions);
- -
- - /**
- - * Configure the input/output tensors for NLClassifier:
- - *
- - * <p>- No special configuration is needed if the model has only one input tensor and one
- - * output tensor.
- - *
- - * <p>- When the model has multiple input or output tensors, use the following configurations
- - * to specifiy the desired tensors: <br>
- - * -- tensor names: {@code inputTensorName}, {@code outputScoreTensorName}, {@code
- - * outputLabelTensorName}<br>
- - * -- tensor indices: {@code inputTensorIndex}, {@code outputScoreTensorIndex}, {@code
- - * outputLabelTensorIndex} <br>
- - * Tensor names has higher priorities than tensor indices in locating the tensors. It means
- - * the tensors will be first located according to tensor names. If not found, then the tensors
- - * will be located according to tensor indices.
- - *
- - * <p>- Failing to match the input text tensor or output score tensor with neither tensor
- - * names nor tensor indices will trigger a runtime error. However, failing to locate the
- - * output label tensor will not trigger an error because the label tensor is optional.
- - */
- -
- - /**
- - * Set the name of the input text tensor, if the model has multiple inputs. Only the input
- - * tensor specified will be used for inference; other input tensors will be ignored. Dafualt
- - * to {@code "INPUT"}.
- - *
- - * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
- - */
- - public abstract Builder setInputTensorName(String inputTensorName);
- -
- - /**
- - * Set the name of the output score tensor, if the model has multiple outputs. Dafualt to
- - * {@code "OUTPUT_SCORE"}.
- - *
- - * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
- - */
- - public abstract Builder setOutputScoreTensorName(String outputScoreTensorName);
- -
- - /**
- - * Set the name of the output label tensor, if the model has multiple outputs. Dafualt to
- - * {@code "OUTPUT_LABEL"}.
- - *
- - * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
- - *
- - * <p>By default, label file should be packed with the output score tensor through Model
- - * Metadata. See the <a
- - * href="https://www.tensorflow.org/lite/convert/metadata_writer_tutorial#natural_language_classifiers">MetadataWriter
- - * for NLClassifier</a>. NLClassifier reads and parses labels from the label file
- - * automatically. However, some models may output a specific label tensor instead. In this
- - * case, NLClassifier reads labels from the output label tensor.
- - */
- - public abstract Builder setOutputLabelTensorName(String outputLabelTensorName);
- -
- - /**
- - * Set the index of the input text tensor among all input tensors, if the model has multiple
- - * inputs. Only the input tensor specified will be used for inference; other input tensors
- - * will be ignored. Dafualt to 0.
- - *
- - * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
- - */
- - public abstract Builder setInputTensorIndex(int inputTensorIndex);
- -
- - /**
- - * Set the index of the output score tensor among all output tensors, if the model has
- - * multiple outputs. Dafualt to 0.
- - *
- - * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
- - */
- - public abstract Builder setOutputScoreTensorIndex(int outputScoreTensorIndex);
- -
- - /**
- - * Set the index of the optional output label tensor among all output tensors, if the model
- - * has multiple outputs.
- - *
- - * <p>See the document above {@code outputLabelTensorName} for more information about what the
- - * output label tensor is.
- - *
- - * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
- - *
- - * <p>{@code outputLabelTensorIndex} dafualts to -1, meaning to disable the output label
- - * tensor.
- - */
- - public abstract Builder setOutputLabelTensorIndex(int outputLabelTensorIndex);
- -
- - public abstract NLClassifierOptions build();
- + /**
- + * Creates {@link NLClassifier} from {@link NLClassifierOptions}.
- + *
- + * @param context Android context
- + * @param modelPath path to the classification model relative to asset dir
- + * @param options configurations for the model.
- + * @return an {@link NLClassifier} instance
- + * @throws IOException if model file fails to load
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static NLClassifier createFromFileAndOptions(
- + Context context, String modelPath, NLClassifierOptions options) throws IOException {
- + return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options);
- }
- - }
- -
- - private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni";
- -
- - /**
- - * Creates {@link NLClassifier} from default {@link NLClassifierOptions}.
- - *
- - * @param context Android context
- - * @param modelPath path to the classification model relative to asset dir
- - * @return an {@link NLClassifier} instance
- - * @throws IOException if model file fails to load
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static NLClassifier createFromFile(Context context, String modelPath) throws IOException {
- - return createFromFileAndOptions(context, modelPath, NLClassifierOptions.builder().build());
- - }
- -
- - /**
- - * Creates {@link NLClassifier} from default {@link NLClassifierOptions}.
- - *
- - * @param modelFile the classification model {@link File} instance
- - * @return an {@link NLClassifier} instance
- - * @throws IOException if model file fails to load
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static NLClassifier createFromFile(File modelFile) throws IOException {
- - return createFromFileAndOptions(modelFile, NLClassifierOptions.builder().build());
- - }
- -
- - /**
- - * Creates {@link NLClassifier} from {@link NLClassifierOptions}.
- - *
- - * @param context Android context
- - * @param modelPath path to the classification model relative to asset dir
- - * @param options configurations for the model.
- - * @return an {@link NLClassifier} instance
- - * @throws IOException if model file fails to load
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static NLClassifier createFromFileAndOptions(
- - Context context, String modelPath, NLClassifierOptions options) throws IOException {
- - return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options);
- - }
- -
- - /**
- - * Creates {@link NLClassifier} from {@link NLClassifierOptions}.
- - *
- - * @param modelFile the classification model {@link File} instance
- - * @param options configurations for the model
- - * @return an {@link NLClassifier} instance
- - * @throws IOException if model file fails to load
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static NLClassifier createFromFileAndOptions(
- - File modelFile, final NLClassifierOptions options) throws IOException {
- - try (ParcelFileDescriptor descriptor =
- - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- - return new NLClassifier(
- - TaskJniUtils.createHandleFromLibrary(
- - new EmptyHandleProvider() {
- +
- + /**
- + * Creates {@link NLClassifier} from {@link NLClassifierOptions}.
- + *
- + * @param modelFile the classification model {@link File} instance
- + * @param options configurations for the model
- + * @return an {@link NLClassifier} instance
- + * @throws IOException if model file fails to load
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static NLClassifier createFromFileAndOptions(
- + File modelFile, final NLClassifierOptions options) throws IOException {
- + try (ParcelFileDescriptor descriptor =
- + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- + return new NLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
- @Override
- public long createHandle() {
- - long baseOptionsHandle =
- - options.getBaseOptions() == null
- - ? 0 // pass an invalid native handle
- - : TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions());
- - return initJniWithFileDescriptor(options, descriptor.getFd(), baseOptionsHandle);
- + long baseOptionsHandle = options.getBaseOptions() == null
- + ? 0 // pass an invalid native handle
- + : TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions());
- + return initJniWithFileDescriptor(
- + options, descriptor.getFd(), baseOptionsHandle);
- }
- - },
- - NL_CLASSIFIER_NATIVE_LIBNAME));
- - }
- - }
- -
- - /**
- - * Creates {@link NLClassifier} with a model {@link ByteBuffer} and {@link NLClassifierOptions}.
- - *
- - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- - * classification model
- - * @param options configurations for the model
- - * @return {@link NLClassifier} instance
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- - * {@link MappedByteBuffer}
- - */
- - public static NLClassifier createFromBufferAndOptions(
- - final ByteBuffer modelBuffer, final NLClassifierOptions options) {
- - if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- - throw new IllegalArgumentException(
- - "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
- + }, NL_CLASSIFIER_NATIVE_LIBNAME));
- + }
- }
-
- - return new NLClassifier(
- - TaskJniUtils.createHandleFromLibrary(
- - new EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - long baseOptionsHandle =
- - options.getBaseOptions() == null
- + /**
- + * Creates {@link NLClassifier} with a model {@link ByteBuffer} and {@link NLClassifierOptions}.
- + *
- + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- + * classification model
- + * @param options configurations for the model
- + * @return {@link NLClassifier} instance
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- + * {@link MappedByteBuffer}
- + */
- + public static NLClassifier createFromBufferAndOptions(
- + final ByteBuffer modelBuffer, final NLClassifierOptions options) {
- + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- + throw new IllegalArgumentException(
- + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
- + }
- +
- + return new NLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + long baseOptionsHandle = options.getBaseOptions() == null
- ? 0 // pass an invalid native handle
- : TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions());
- return initJniWithByteBuffer(options, modelBuffer, baseOptionsHandle);
- - }
- - },
- - NL_CLASSIFIER_NATIVE_LIBNAME));
- - }
- -
- - /**
- - * Performs classification on a string input, returns classified {@link Category}s.
- - *
- - * @param text input text to the model
- - * @return a list of Category results
- - */
- - public List<Category> classify(String text) {
- - return classifyNative(getNativeHandle(), text);
- - }
- -
- - /**
- - * Constructor to initialize the JNI with a pointer from C++.
- - *
- - * @param nativeHandle a pointer referencing memory allocated in C++.
- - */
- - protected NLClassifier(long nativeHandle) {
- - super(nativeHandle);
- - }
- -
- - @Override
- - protected void deinit(long nativeHandle) {
- - deinitJni(nativeHandle);
- - }
- -
- - private static native long initJniWithByteBuffer(
- - NLClassifierOptions options, ByteBuffer modelBuffer, long baseOptionsHandle);
- -
- - private static native long initJniWithFileDescriptor(
- - NLClassifierOptions options, int fd, long baseOptionsHandle);
- -
- - private static native List<Category> classifyNative(long nativeHandle, String text);
- -
- - /**
- - * Native implementation to release memory pointed by the pointer.
- - *
- - * @param nativeHandle pointer to memory allocated
- - */
- - private native void deinitJni(long nativeHandle);
- + }
- + }, NL_CLASSIFIER_NATIVE_LIBNAME));
- + }
- +
- + /**
- + * Performs classification on a string input, returns classified {@link Category}s.
- + *
- + * @param text input text to the model
- + * @return a list of Category results
- + */
- + public List<Category> classify(String text) {
- + return classifyNative(getNativeHandle(), text);
- + }
- +
- + /**
- + * Constructor to initialize the JNI with a pointer from C++.
- + *
- + * @param nativeHandle a pointer referencing memory allocated in C++.
- + */
- + protected NLClassifier(long nativeHandle) {
- + super(nativeHandle);
- + }
- +
- + @Override
- + protected void deinit(long nativeHandle) {
- + deinitJni(nativeHandle);
- + }
- +
- + private static native long initJniWithByteBuffer(
- + NLClassifierOptions options, ByteBuffer modelBuffer, long baseOptionsHandle);
- +
- + private static native long initJniWithFileDescriptor(
- + NLClassifierOptions options, int fd, long baseOptionsHandle);
- +
- + private static native List<Category> classifyNative(long nativeHandle, String text);
- +
- + /**
- + * Native implementation to release memory pointed by the pointer.
- + *
- + * @param nativeHandle pointer to memory allocated
- + */
- + private native void deinitJni(long nativeHandle);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
- index aafa2c88c55e8..39648d9bb4042 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
- @@ -17,11 +17,9 @@ package org.tensorflow.lite.task.text.qa;
-
- import android.content.Context;
- import android.os.ParcelFileDescriptor;
- +
- import com.google.auto.value.AutoValue;
- -import java.io.File;
- -import java.io.IOException;
- -import java.nio.ByteBuffer;
- -import java.util.List;
- +
- import org.tensorflow.lite.task.core.BaseOptions;
- import org.tensorflow.lite.task.core.BaseTaskApi;
- import org.tensorflow.lite.task.core.TaskJniUtils;
- @@ -29,6 +27,11 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
- import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider;
- import org.tensorflow.lite.task.core.TaskJniUtils.MultipleBuffersHandleProvider;
-
- +import java.io.File;
- +import java.io.IOException;
- +import java.nio.ByteBuffer;
- +import java.util.List;
- +
- /**
- * Returns the most possible answers on a given question for QA models (BERT, Albert, etc.).
- *
- @@ -45,225 +48,204 @@ import org.tensorflow.lite.task.core.TaskJniUtils.MultipleBuffersHandleProvider;
- * </ul>
- */
- public class BertQuestionAnswerer extends BaseTaskApi implements QuestionAnswerer {
- - private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "task_text_jni";
- - private static final int OPTIONAL_FD_LENGTH = -1;
- - private static final int OPTIONAL_FD_OFFSET = -1;
- -
- - /**
- - * Creates a {@link BertQuestionAnswerer} instance from the default {@link
- - * BertQuestionAnswererOptions}.
- - *
- - * @param context android context
- - * @param modelPath file path to the model with metadata. Note: The model should not be compressed
- - * @return a {@link BertQuestionAnswerer} instance
- - * @throws IOException if model file fails to load
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static BertQuestionAnswerer createFromFile(Context context, String modelPath)
- - throws IOException {
- - return createFromFileAndOptions(
- - context, modelPath, BertQuestionAnswererOptions.builder().build());
- - }
- + private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "task_text_jni";
- + private static final int OPTIONAL_FD_LENGTH = -1;
- + private static final int OPTIONAL_FD_OFFSET = -1;
- +
- + /**
- + * Creates a {@link BertQuestionAnswerer} instance from the default {@link
- + * BertQuestionAnswererOptions}.
- + *
- + * @param context android context
- + * @param modelPath file path to the model with metadata. Note: The model should not be
- + * compressed
- + * @return a {@link BertQuestionAnswerer} instance
- + * @throws IOException if model file fails to load
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static BertQuestionAnswerer createFromFile(Context context, String modelPath)
- + throws IOException {
- + return createFromFileAndOptions(
- + context, modelPath, BertQuestionAnswererOptions.builder().build());
- + }
-
- - /**
- - * Creates a {@link BertQuestionAnswerer} instance from the default {@link
- - * BertQuestionAnswererOptions}.
- - *
- - * @param modelFile a {@link File} object of the model
- - * @return a {@link BertQuestionAnswerer} instance
- - * @throws IOException if model file fails to load
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static BertQuestionAnswerer createFromFile(File modelFile) throws IOException {
- - return createFromFileAndOptions(modelFile, BertQuestionAnswererOptions.builder().build());
- - }
- + /**
- + * Creates a {@link BertQuestionAnswerer} instance from the default {@link
- + * BertQuestionAnswererOptions}.
- + *
- + * @param modelFile a {@link File} object of the model
- + * @return a {@link BertQuestionAnswerer} instance
- + * @throws IOException if model file fails to load
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static BertQuestionAnswerer createFromFile(File modelFile) throws IOException {
- + return createFromFileAndOptions(modelFile, BertQuestionAnswererOptions.builder().build());
- + }
-
- - /**
- - * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}.
- - *
- - * @param context android context
- - * @param modelPath file path to the model with metadata. Note: The model should not be compressed
- - * @return a {@link BertQuestionAnswerer} instance
- - * @throws IOException if model file fails to load
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static BertQuestionAnswerer createFromFileAndOptions(
- - Context context, String modelPath, BertQuestionAnswererOptions options) throws IOException {
- - return new BertQuestionAnswerer(
- - TaskJniUtils.createHandleFromFdAndOptions(
- - context,
- - new FdAndOptionsHandleProvider<BertQuestionAnswererOptions>() {
- - @Override
- - public long createHandle(
- - int fileDescriptor,
- - long fileDescriptorLength,
- - long fileDescriptorOffset,
- - BertQuestionAnswererOptions options) {
- - return initJniWithFileDescriptor(
- - fileDescriptor,
- - fileDescriptorLength,
- - fileDescriptorOffset,
- - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
- - }
- - },
- - BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
- - modelPath,
- - options));
- - }
- + /**
- + * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}.
- + *
- + * @param context android context
- + * @param modelPath file path to the model with metadata. Note: The model should not be
- + * compressed
- + * @return a {@link BertQuestionAnswerer} instance
- + * @throws IOException if model file fails to load
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static BertQuestionAnswerer createFromFileAndOptions(Context context, String modelPath,
- + BertQuestionAnswererOptions options) throws IOException {
- + return new BertQuestionAnswerer(TaskJniUtils.createHandleFromFdAndOptions(
- + context, new FdAndOptionsHandleProvider<BertQuestionAnswererOptions>() {
- + @Override
- + public long createHandle(int fileDescriptor, long fileDescriptorLength,
- + long fileDescriptorOffset, BertQuestionAnswererOptions options) {
- + return initJniWithFileDescriptor(fileDescriptor, fileDescriptorLength,
- + fileDescriptorOffset,
- + TaskJniUtils.createProtoBaseOptionsHandle(
- + options.getBaseOptions()));
- + }
- + }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, modelPath, options));
- + }
-
- - /**
- - * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}.
- - *
- - * @param modelFile a {@link File} object of the model
- - * @return a {@link BertQuestionAnswerer} instance
- - * @throws IOException if model file fails to load
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static BertQuestionAnswerer createFromFileAndOptions(
- - File modelFile, final BertQuestionAnswererOptions options) throws IOException {
- - try (ParcelFileDescriptor descriptor =
- - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- - return new BertQuestionAnswerer(
- - TaskJniUtils.createHandleFromLibrary(
- - new EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - return initJniWithFileDescriptor(
- - /*fileDescriptor=*/ descriptor.getFd(),
- - /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
- - /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
- - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
- - }
- - },
- - BERT_QUESTION_ANSWERER_NATIVE_LIBNAME));
- + /**
- + * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}.
- + *
- + * @param modelFile a {@link File} object of the model
- + * @return a {@link BertQuestionAnswerer} instance
- + * @throws IOException if model file fails to load
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static BertQuestionAnswerer createFromFileAndOptions(
- + File modelFile, final BertQuestionAnswererOptions options) throws IOException {
- + try (ParcelFileDescriptor descriptor =
- + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- + return new BertQuestionAnswerer(
- + TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + return initJniWithFileDescriptor(
- + /*fileDescriptor=*/descriptor.getFd(),
- + /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH,
- + /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET,
- + TaskJniUtils.createProtoBaseOptionsHandle(
- + options.getBaseOptions()));
- + }
- + }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME));
- + }
- }
- - }
-
- - /**
- - * Creates a {@link BertQuestionAnswerer} instance with a Bert model and a vocabulary file.
- - *
- - * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
- - *
- - * @param context android context
- - * @param modelPath file path to the Bert model. Note: The model should not be compressed
- - * @param vocabPath file path to the vocabulary file. Note: The file should not be compressed
- - * @return a {@link BertQuestionAnswerer} instance
- - * @throws IOException If model file fails to load
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static BertQuestionAnswerer createBertQuestionAnswererFromFile(
- - Context context, String modelPath, String vocabPath) throws IOException {
- - return new BertQuestionAnswerer(
- - TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
- - context,
- - new MultipleBuffersHandleProvider() {
- - @Override
- - public long createHandle(ByteBuffer... buffers) {
- - return initJniWithBertByteBuffers(buffers);
- - }
- - },
- - BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
- - modelPath,
- - vocabPath));
- - }
- + /**
- + * Creates a {@link BertQuestionAnswerer} instance with a Bert model and a vocabulary file.
- + *
- + * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
- + *
- + * @param context android context
- + * @param modelPath file path to the Bert model. Note: The model should not be compressed
- + * @param vocabPath file path to the vocabulary file. Note: The file should not be compressed
- + * @return a {@link BertQuestionAnswerer} instance
- + * @throws IOException If model file fails to load
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static BertQuestionAnswerer createBertQuestionAnswererFromFile(
- + Context context, String modelPath, String vocabPath) throws IOException {
- + return new BertQuestionAnswerer(TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
- + context, new MultipleBuffersHandleProvider() {
- + @Override
- + public long createHandle(ByteBuffer... buffers) {
- + return initJniWithBertByteBuffers(buffers);
- + }
- + }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, modelPath, vocabPath));
- + }
-
- - /**
- - * Creates a {@link BertQuestionAnswerer} instance with an Albert model and a sentence piece model
- - * file.
- - *
- - * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
- - *
- - * @param context android context
- - * @param modelPath file path to the Albert model. Note: The model should not be compressed
- - * @param sentencePieceModelPath file path to the sentence piece model file. Note: The model
- - * should not be compressed
- - * @return a {@link BertQuestionAnswerer} instance
- - * @throws IOException If model file fails to load
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static BertQuestionAnswerer createAlbertQuestionAnswererFromFile(
- - Context context, String modelPath, String sentencePieceModelPath) throws IOException {
- - return new BertQuestionAnswerer(
- - TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
- - context,
- - new MultipleBuffersHandleProvider() {
- - @Override
- - public long createHandle(ByteBuffer... buffers) {
- - return initJniWithAlbertByteBuffers(buffers);
- - }
- - },
- - BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
- - modelPath,
- - sentencePieceModelPath));
- - }
- + /**
- + * Creates a {@link BertQuestionAnswerer} instance with an Albert model and a sentence piece
- + * model file.
- + *
- + * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
- + *
- + * @param context android context
- + * @param modelPath file path to the Albert model. Note: The model should not be compressed
- + * @param sentencePieceModelPath file path to the sentence piece model file. Note: The model
- + * should not be compressed
- + * @return a {@link BertQuestionAnswerer} instance
- + * @throws IOException If model file fails to load
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static BertQuestionAnswerer createAlbertQuestionAnswererFromFile(
- + Context context, String modelPath, String sentencePieceModelPath) throws IOException {
- + return new BertQuestionAnswerer(TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
- + context, new MultipleBuffersHandleProvider() {
- + @Override
- + public long createHandle(ByteBuffer... buffers) {
- + return initJniWithAlbertByteBuffers(buffers);
- + }
- + }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, modelPath, sentencePieceModelPath));
- + }
-
- - /** Options for setting up a {@link BertQuestionAnswerer}. */
- - @AutoValue
- - public abstract static class BertQuestionAnswererOptions {
- - abstract BaseOptions getBaseOptions();
- + /** Options for setting up a {@link BertQuestionAnswerer}. */
- + @AutoValue
- + public abstract static class BertQuestionAnswererOptions {
- + abstract BaseOptions getBaseOptions();
-
- - public static Builder builder() {
- - return new AutoValue_BertQuestionAnswerer_BertQuestionAnswererOptions.Builder()
- - .setBaseOptions(BaseOptions.builder().build());
- - }
- + public static Builder builder() {
- + return new AutoValue_BertQuestionAnswerer_BertQuestionAnswererOptions.Builder()
- + .setBaseOptions(BaseOptions.builder().build());
- + }
-
- - /** Builder for {@link BertQuestionAnswererOptions}. */
- - @AutoValue.Builder
- - public abstract static class Builder {
- - /** Sets the general options to configure Task APIs, such as accelerators. */
- - public abstract Builder setBaseOptions(BaseOptions baseOptions);
- + /** Builder for {@link BertQuestionAnswererOptions}. */
- + @AutoValue.Builder
- + public abstract static class Builder {
- + /** Sets the general options to configure Task APIs, such as accelerators. */
- + public abstract Builder setBaseOptions(BaseOptions baseOptions);
-
- - public abstract BertQuestionAnswererOptions build();
- + public abstract BertQuestionAnswererOptions build();
- + }
- }
- - }
-
- - @Override
- - public List<QaAnswer> answer(String context, String question) {
- - checkNotClosed();
- - return answerNative(getNativeHandle(), context, question);
- - }
- + @Override
- + public List<QaAnswer> answer(String context, String question) {
- + checkNotClosed();
- + return answerNative(getNativeHandle(), context, question);
- + }
-
- - private BertQuestionAnswerer(long nativeHandle) {
- - super(nativeHandle);
- - }
- + private BertQuestionAnswerer(long nativeHandle) {
- + super(nativeHandle);
- + }
-
- - // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
- - private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);
- + // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
- + private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);
-
- - // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is sentencepiece model file
- - // buffer.
- - private static native long initJniWithAlbertByteBuffers(ByteBuffer... modelBuffers);
- + // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is sentencepiece model file
- + // buffer.
- + private static native long initJniWithAlbertByteBuffers(ByteBuffer... modelBuffers);
-
- - private static native long initJniWithFileDescriptor(
- - int fileDescriptor,
- - long fileDescriptorLength,
- - long fileDescriptorOffset,
- - long baseOptionsHandle);
- + private static native long initJniWithFileDescriptor(int fileDescriptor,
- + long fileDescriptorLength, long fileDescriptorOffset, long baseOptionsHandle);
-
- - private static native List<QaAnswer> answerNative(
- - long nativeHandle, String context, String question);
- + private static native List<QaAnswer> answerNative(
- + long nativeHandle, String context, String question);
-
- - @Override
- - protected void deinit(long nativeHandle) {
- - deinitJni(nativeHandle);
- - }
- + @Override
- + protected void deinit(long nativeHandle) {
- + deinitJni(nativeHandle);
- + }
-
- - /**
- - * Native implementation to release memory pointed by the pointer.
- - *
- - * @param nativeHandle pointer to memory allocated
- - */
- - private native void deinitJni(long nativeHandle);
- + /**
- + * Native implementation to release memory pointed by the pointer.
- + *
- + * @param nativeHandle pointer to memory allocated
- + */
- + private native void deinitJni(long nativeHandle);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java
- index b75a07e10cc7b..50917c035a995 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java
- @@ -22,37 +22,37 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection;
- * position information to the context.
- */
- public class QaAnswer {
- - public Pos pos;
- - public String text;
- -
- - @UsedByReflection("bert_question_answerer_jni.cc")
- - public QaAnswer(String text, Pos pos) {
- - this.text = text;
- - this.pos = pos;
- - }
- -
- - public QaAnswer(String text, int start, int end, float logit) {
- - this(text, new Pos(start, end, logit));
- - }
- -
- - /**
- - * Position information of the answer relative to context. It is sortable in descending order
- - * based on logit.
- - */
- - public static class Pos implements Comparable<Pos> {
- - public int start;
- - public int end;
- - public float logit;
- -
- - public Pos(int start, int end, float logit) {
- - this.start = start;
- - this.end = end;
- - this.logit = logit;
- + public Pos pos;
- + public String text;
- +
- + @UsedByReflection("bert_question_answerer_jni.cc")
- + public QaAnswer(String text, Pos pos) {
- + this.text = text;
- + this.pos = pos;
- + }
- +
- + public QaAnswer(String text, int start, int end, float logit) {
- + this(text, new Pos(start, end, logit));
- }
-
- - @Override
- - public int compareTo(Pos other) {
- - return Float.compare(other.logit, this.logit);
- + /**
- + * Position information of the answer relative to context. It is sortable in descending order
- + * based on logit.
- + */
- + public static class Pos implements Comparable<Pos> {
- + public int start;
- + public int end;
- + public float logit;
- +
- + public Pos(int start, int end, float logit) {
- + this.start = start;
- + this.end = end;
- + this.logit = logit;
- + }
- +
- + @Override
- + public int compareTo(Pos other) {
- + return Float.compare(other.logit, this.logit);
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java
- index 8df6d3794e1b5..7a59a99d7fddf 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java
- @@ -19,14 +19,13 @@ import java.util.List;
-
- /** API to answer questions based on context. */
- public interface QuestionAnswerer {
- -
- - /**
- - * Answers question based on context, and returns a list of possible {@link QaAnswer}s. Could be
- - * empty if no answer was found from the given context.
- - *
- - * @param context context the question bases on
- - * @param question question to ask
- - * @return a list of possible answers in {@link QaAnswer}
- - */
- - List<QaAnswer> answer(String context, String question);
- + /**
- + * Answers question based on context, and returns a list of possible {@link QaAnswer}s. Could be
- + * empty if no answer was found from the given context.
- + *
- + * @param context context the question bases on
- + * @param question question to ask
- + * @return a list of possible answers in {@link QaAnswer}
- + */
- + List<QaAnswer> answer(String context, String question);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/TextSearcher.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/TextSearcher.java
- index 1a32d10e47114..ea3b1b8c25b34 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/TextSearcher.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/TextSearcher.java
- @@ -18,12 +18,9 @@ package org.tensorflow.lite.task.text.searcher;
- import android.content.Context;
- import android.content.res.AssetFileDescriptor;
- import android.os.ParcelFileDescriptor;
- +
- import com.google.auto.value.AutoValue;
- -import java.io.File;
- -import java.io.IOException;
- -import java.nio.ByteBuffer;
- -import java.nio.MappedByteBuffer;
- -import java.util.List;
- +
- import org.tensorflow.lite.task.core.BaseOptions;
- import org.tensorflow.lite.task.core.BaseTaskApi;
- import org.tensorflow.lite.task.core.TaskJniUtils;
- @@ -31,6 +28,12 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
- import org.tensorflow.lite.task.processor.NearestNeighbor;
- import org.tensorflow.lite.task.processor.SearcherOptions;
-
- +import java.io.File;
- +import java.io.IOException;
- +import java.nio.ByteBuffer;
- +import java.nio.MappedByteBuffer;
- +import java.util.List;
- +
- /**
- * Performs similarity search on text string.
- *
- @@ -67,227 +70,193 @@ import org.tensorflow.lite.task.processor.SearcherOptions;
- * the single file format (index file packed in the model) is supported.
- */
- public final class TextSearcher extends BaseTaskApi {
- + private static final String TEXT_SEARCHER_NATIVE_LIB = "task_text_jni";
- + private static final int OPTIONAL_FD_LENGTH = -1;
- + private static final int OPTIONAL_FD_OFFSET = -1;
-
- - private static final String TEXT_SEARCHER_NATIVE_LIB = "task_text_jni";
- - private static final int OPTIONAL_FD_LENGTH = -1;
- - private static final int OPTIONAL_FD_OFFSET = -1;
- + /**
- + * Creates an {@link TextSearcher} instance from {@link TextSearcherOptions}.
- + *
- + * @param modelPath path of the search model with metadata in the assets
- + * @throws IOException if an I/O error occurs when loading the tflite model or the index file
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static TextSearcher createFromFileAndOptions(Context context, String modelPath,
- + final TextSearcherOptions options) throws IOException {
- + try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
- + return createFromModelFdAndOptions(
- + /*modelDescriptor=*/assetFileDescriptor.getParcelFileDescriptor().getFd(),
- + /*modelDescriptorLength=*/assetFileDescriptor.getLength(),
- + /*modelDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options);
- + }
- + }
-
- - /**
- - * Creates an {@link TextSearcher} instance from {@link TextSearcherOptions}.
- - *
- - * @param modelPath path of the search model with metadata in the assets
- - * @throws IOException if an I/O error occurs when loading the tflite model or the index file
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static TextSearcher createFromFileAndOptions(
- - Context context, String modelPath, final TextSearcherOptions options) throws IOException {
- - try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
- - return createFromModelFdAndOptions(
- - /*modelDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
- - /*modelDescriptorLength=*/ assetFileDescriptor.getLength(),
- - /*modelDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
- - options);
- + /**
- + * Creates an {@link TextSearcher} instance.
- + *
- + * @param modelFile the search model {@link File} instance
- + * @throws IOException if an I/O error occurs when loading the tflite model or the index file
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static TextSearcher createFromFileAndOptions(
- + File modelFile, final TextSearcherOptions options) throws IOException {
- + try (ParcelFileDescriptor descriptor =
- + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- + return createFromModelFdAndOptions(
- + /*modelDescriptor=*/descriptor.getFd(),
- + /*modelDescriptorLength=*/OPTIONAL_FD_LENGTH,
- + /*modelDescriptorOffset=*/OPTIONAL_FD_OFFSET, options);
- + }
- }
- - }
-
- - /**
- - * Creates an {@link TextSearcher} instance.
- - *
- - * @param modelFile the search model {@link File} instance
- - * @throws IOException if an I/O error occurs when loading the tflite model or the index file
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static TextSearcher createFromFileAndOptions(
- - File modelFile, final TextSearcherOptions options) throws IOException {
- - try (ParcelFileDescriptor descriptor =
- - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- - return createFromModelFdAndOptions(
- - /*modelDescriptor=*/ descriptor.getFd(),
- - /*modelDescriptorLength=*/ OPTIONAL_FD_LENGTH,
- - /*modelDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
- - options);
- + /**
- + * Creates an {@link TextSearcher} instance with a model buffer and {@link TextSearcherOptions}.
- + *
- + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the search
- + * model
- + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- + * {@link MappedByteBuffer}
- + * @throws IOException if an I/O error occurs when loading the index file
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static TextSearcher createFromBufferAndOptions(
- + final ByteBuffer modelBuffer, final TextSearcherOptions options) throws IOException {
- + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- + throw new IllegalArgumentException(
- + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
- + }
- + if (options.getSearcherOptions().getIndexFile() != null) {
- + try (ParcelFileDescriptor indexDescriptor =
- + ParcelFileDescriptor.open(options.getSearcherOptions().getIndexFile(),
- + ParcelFileDescriptor.MODE_READ_ONLY)) {
- + return createFromBufferAndOptionsImpl(
- + modelBuffer, options, indexDescriptor.getFd());
- + }
- + } else {
- + return createFromBufferAndOptionsImpl(modelBuffer, options, /*indexFd=*/0);
- + }
- }
- - }
-
- - /**
- - * Creates an {@link TextSearcher} instance with a model buffer and {@link TextSearcherOptions}.
- - *
- - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the search
- - * model
- - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- - * {@link MappedByteBuffer}
- - * @throws IOException if an I/O error occurs when loading the index file
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static TextSearcher createFromBufferAndOptions(
- - final ByteBuffer modelBuffer, final TextSearcherOptions options) throws IOException {
- - if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- - throw new IllegalArgumentException(
- - "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
- + public static TextSearcher createFromBufferAndOptionsImpl(
- + final ByteBuffer modelBuffer, final TextSearcherOptions options, final int indexFd) {
- + return new TextSearcher(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + return initJniWithByteBuffer(modelBuffer,
- + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()),
- + options.getSearcherOptions().getL2Normalize(),
- + options.getSearcherOptions().getQuantize(), indexFd,
- + options.getSearcherOptions().getMaxResults());
- + }
- + }, TEXT_SEARCHER_NATIVE_LIB));
- }
- - if (options.getSearcherOptions().getIndexFile() != null) {
- - try (ParcelFileDescriptor indexDescriptor =
- - ParcelFileDescriptor.open(
- - options.getSearcherOptions().getIndexFile(), ParcelFileDescriptor.MODE_READ_ONLY)) {
- - return createFromBufferAndOptionsImpl(modelBuffer, options, indexDescriptor.getFd());
- - }
- - } else {
- - return createFromBufferAndOptionsImpl(modelBuffer, options, /*indexFd=*/ 0);
- +
- + /**
- + * Constructor to initialize the JNI with a pointer from C++.
- + *
- + * @param nativeHandle a pointer referencing memory allocated in C++
- + */
- + TextSearcher(long nativeHandle) {
- + super(nativeHandle);
- }
- - }
-
- - public static TextSearcher createFromBufferAndOptionsImpl(
- - final ByteBuffer modelBuffer, final TextSearcherOptions options, final int indexFd) {
- - return new TextSearcher(
- - TaskJniUtils.createHandleFromLibrary(
- - new EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - return initJniWithByteBuffer(
- - modelBuffer,
- - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()),
- - options.getSearcherOptions().getL2Normalize(),
- - options.getSearcherOptions().getQuantize(),
- - indexFd,
- - options.getSearcherOptions().getMaxResults());
- - }
- - },
- - TEXT_SEARCHER_NATIVE_LIB));
- - }
- + /** Options for setting up an TextSearcher. */
- + @AutoValue
- + public abstract static class TextSearcherOptions {
- + abstract BaseOptions getBaseOptions();
-
- - /**
- - * Constructor to initialize the JNI with a pointer from C++.
- - *
- - * @param nativeHandle a pointer referencing memory allocated in C++
- - */
- - TextSearcher(long nativeHandle) {
- - super(nativeHandle);
- - }
- + abstract SearcherOptions getSearcherOptions();
-
- - /** Options for setting up an TextSearcher. */
- - @AutoValue
- - public abstract static class TextSearcherOptions {
- + public static Builder builder() {
- + return new AutoValue_TextSearcher_TextSearcherOptions.Builder()
- + .setBaseOptions(BaseOptions.builder().build())
- + .setSearcherOptions(SearcherOptions.builder().build());
- + }
-
- - abstract BaseOptions getBaseOptions();
- + /** Builder for {@link TextSearcherOptions}. */
- + @AutoValue.Builder
- + public abstract static class Builder {
- + /** Sets the general options to configure Task APIs, such as accelerators. */
- + public abstract Builder setBaseOptions(BaseOptions baseOptions);
-
- - abstract SearcherOptions getSearcherOptions();
- + /** Sets the options to configure Searcher API. */
- + public abstract Builder setSearcherOptions(SearcherOptions searcherOptions);
-
- - public static Builder builder() {
- - return new AutoValue_TextSearcher_TextSearcherOptions.Builder()
- - .setBaseOptions(BaseOptions.builder().build())
- - .setSearcherOptions(SearcherOptions.builder().build());
- + public abstract TextSearcherOptions build();
- + }
- }
-
- - /** Builder for {@link TextSearcherOptions}. */
- - @AutoValue.Builder
- - public abstract static class Builder {
- - /** Sets the general options to configure Task APIs, such as accelerators. */
- - public abstract Builder setBaseOptions(BaseOptions baseOptions);
- -
- - /** Sets the options to configure Searcher API. */
- - public abstract Builder setSearcherOptions(SearcherOptions searcherOptions);
- -
- - public abstract TextSearcherOptions build();
- + /**
- + * Performs embedding extraction on the provided string input, followed by nearest-neighbor
- + * search in the index.
- + *
- + * @param text input text query to the model
- + */
- + public List<NearestNeighbor> search(String text) {
- + return searchNative(getNativeHandle(), text);
- }
- - }
- -
- - /**
- - * Performs embedding extraction on the provided string input, followed by nearest-neighbor search
- - * in the index.
- - *
- - * @param text input text query to the model
- - */
- - public List<NearestNeighbor> search(String text) {
- - return searchNative(getNativeHandle(), text);
- - }
-
- - private static TextSearcher createFromModelFdAndOptions(
- - final int modelDescriptor,
- - final long modelDescriptorLength,
- - final long modelDescriptorOffset,
- - final TextSearcherOptions options)
- - throws IOException {
- - if (options.getSearcherOptions().getIndexFile() != null) {
- - // indexDescriptor must be alive before TextSearcher is initialized completely in the native
- - // layer.
- - try (ParcelFileDescriptor indexDescriptor =
- - ParcelFileDescriptor.open(
- - options.getSearcherOptions().getIndexFile(), ParcelFileDescriptor.MODE_READ_ONLY)) {
- - return createFromModelFdAndOptionsImpl(
- - modelDescriptor,
- - modelDescriptorLength,
- - modelDescriptorOffset,
- - options,
- - indexDescriptor.getFd());
- - }
- - } else {
- - // Index file is not configured. We'll check if the model contains one in the native layer.
- - return createFromModelFdAndOptionsImpl(
- - modelDescriptor, modelDescriptorLength, modelDescriptorOffset, options, /*indexFd=*/ 0);
- + private static TextSearcher createFromModelFdAndOptions(final int modelDescriptor,
- + final long modelDescriptorLength, final long modelDescriptorOffset,
- + final TextSearcherOptions options) throws IOException {
- + if (options.getSearcherOptions().getIndexFile() != null) {
- + // indexDescriptor must be alive before TextSearcher is initialized completely in the
- + // native layer.
- + try (ParcelFileDescriptor indexDescriptor =
- + ParcelFileDescriptor.open(options.getSearcherOptions().getIndexFile(),
- + ParcelFileDescriptor.MODE_READ_ONLY)) {
- + return createFromModelFdAndOptionsImpl(modelDescriptor, modelDescriptorLength,
- + modelDescriptorOffset, options, indexDescriptor.getFd());
- + }
- + } else {
- + // Index file is not configured. We'll check if the model contains one in the native
- + // layer.
- + return createFromModelFdAndOptionsImpl(modelDescriptor, modelDescriptorLength,
- + modelDescriptorOffset, options, /*indexFd=*/0);
- + }
- }
- - }
-
- - private static TextSearcher createFromModelFdAndOptionsImpl(
- - final int modelDescriptor,
- - final long modelDescriptorLength,
- - final long modelDescriptorOffset,
- - final TextSearcherOptions options,
- - final int indexFd) {
- - long nativeHandle =
- - TaskJniUtils.createHandleFromLibrary(
- - new EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - return initJniWithModelFdAndOptions(
- - modelDescriptor,
- - modelDescriptorLength,
- - modelDescriptorOffset,
- - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()),
- - options.getSearcherOptions().getL2Normalize(),
- - options.getSearcherOptions().getQuantize(),
- - indexFd,
- - options.getSearcherOptions().getMaxResults());
- - }
- - },
- - TEXT_SEARCHER_NATIVE_LIB);
- - return new TextSearcher(nativeHandle);
- - }
- + private static TextSearcher createFromModelFdAndOptionsImpl(final int modelDescriptor,
- + final long modelDescriptorLength, final long modelDescriptorOffset,
- + final TextSearcherOptions options, final int indexFd) {
- + long nativeHandle = TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + return initJniWithModelFdAndOptions(modelDescriptor, modelDescriptorLength,
- + modelDescriptorOffset,
- + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()),
- + options.getSearcherOptions().getL2Normalize(),
- + options.getSearcherOptions().getQuantize(), indexFd,
- + options.getSearcherOptions().getMaxResults());
- + }
- + }, TEXT_SEARCHER_NATIVE_LIB);
- + return new TextSearcher(nativeHandle);
- + }
-
- - private static native long initJniWithModelFdAndOptions(
- - int modelDescriptor,
- - long modelDescriptorLength,
- - long modelDescriptorOffset,
- - long baseOptionsHandle,
- - boolean l2Normalize,
- - boolean quantize,
- - int indexDescriptor,
- - int maxResults);
- + private static native long initJniWithModelFdAndOptions(int modelDescriptor,
- + long modelDescriptorLength, long modelDescriptorOffset, long baseOptionsHandle,
- + boolean l2Normalize, boolean quantize, int indexDescriptor, int maxResults);
-
- - private static native long initJniWithByteBuffer(
- - ByteBuffer modelBuffer,
- - long baseOptionsHandle,
- - boolean l2Normalize,
- - boolean quantize,
- - int indexFileDescriptor,
- - int maxResults);
- + private static native long initJniWithByteBuffer(ByteBuffer modelBuffer, long baseOptionsHandle,
- + boolean l2Normalize, boolean quantize, int indexFileDescriptor, int maxResults);
-
- - /** The native method to search an input text string. */
- - private static native List<NearestNeighbor> searchNative(long nativeHandle, String text);
- + /** The native method to search an input text string. */
- + private static native List<NearestNeighbor> searchNative(long nativeHandle, String text);
-
- - @Override
- - protected void deinit(long nativeHandle) {
- - deinitJni(nativeHandle);
- - }
- + @Override
- + protected void deinit(long nativeHandle) {
- + deinitJni(nativeHandle);
- + }
-
- - /**
- - * Native implementation to release memory pointed by the pointer.
- - *
- - * @param nativeHandle pointer to memory allocated
- - */
- - private native void deinitJni(long nativeHandle);
- + /**
- + * Native implementation to release memory pointed by the pointer.
- + *
- + * @param nativeHandle pointer to memory allocated
- + */
- + private native void deinitJni(long nativeHandle);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java
- index 88aeecc8d62ca..e59a2e89e86f4 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java
- @@ -16,11 +16,13 @@ limitations under the License.
- package org.tensorflow.lite.task.vision.classifier;
-
- import com.google.auto.value.AutoValue;
- +
- +import org.tensorflow.lite.support.label.Category;
- +import org.tensorflow.lite.task.core.annotations.UsedByReflection;
- +
- import java.util.ArrayList;
- import java.util.Collections;
- import java.util.List;
- -import org.tensorflow.lite.support.label.Category;
- -import org.tensorflow.lite.task.core.annotations.UsedByReflection;
-
- /**
- * The classification results of one head in a multihead (a.k.a. multi-output) {@link
- @@ -31,16 +33,15 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection;
- @AutoValue
- @UsedByReflection("image_classifier_jni.cc")
- public abstract class Classifications {
- + @UsedByReflection("image_classifier_jni.cc")
- + static Classifications create(List<Category> categories, int headIndex) {
- + return new AutoValue_Classifications(
- + Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex);
- + }
-
- - @UsedByReflection("image_classifier_jni.cc")
- - static Classifications create(List<Category> categories, int headIndex) {
- - return new AutoValue_Classifications(
- - Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex);
- - }
- -
- - // Same reason for not using ImmutableList as stated in
- - // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}.
- - public abstract List<Category> getCategories();
- + // Same reason for not using ImmutableList as stated in
- + // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}.
- + public abstract List<Category> getCategories();
-
- - public abstract int getHeadIndex();
- + public abstract int getHeadIndex();
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java
- index 90628928198d5..5b5be73bcca1e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java
- @@ -18,14 +18,9 @@ package org.tensorflow.lite.task.vision.classifier;
- import android.content.Context;
- import android.graphics.Rect;
- import android.os.ParcelFileDescriptor;
- +
- import com.google.android.odml.image.MlImage;
- -import java.io.File;
- -import java.io.IOException;
- -import java.nio.ByteBuffer;
- -import java.nio.MappedByteBuffer;
- -import java.util.ArrayList;
- -import java.util.Collections;
- -import java.util.List;
- +
- import org.tensorflow.lite.support.image.MlImageAdapter;
- import org.tensorflow.lite.support.image.TensorImage;
- import org.tensorflow.lite.task.core.BaseOptions;
- @@ -37,6 +32,14 @@ import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
- import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
- import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
-
- +import java.io.File;
- +import java.io.IOException;
- +import java.nio.ByteBuffer;
- +import java.nio.MappedByteBuffer;
- +import java.util.ArrayList;
- +import java.util.Collections;
- +import java.util.List;
- +
- /**
- * Performs classification on images.
- *
- @@ -71,476 +74,449 @@ import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
- * Hub.</a>.
- */
- public final class ImageClassifier extends BaseVisionTaskApi {
- + private static final String IMAGE_CLASSIFIER_NATIVE_LIB = "task_vision_jni";
- + private static final int OPTIONAL_FD_LENGTH = -1;
- + private static final int OPTIONAL_FD_OFFSET = -1;
- +
- + /**
- + * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}.
- + *
- + * @param modelPath path of the classification model with metadata in the assets
- + * @throws IOException if an I/O error occurs when loading the tflite model
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ImageClassifier createFromFile(Context context, String modelPath)
- + throws IOException {
- + return createFromFileAndOptions(
- + context, modelPath, ImageClassifierOptions.builder().build());
- + }
-
- - private static final String IMAGE_CLASSIFIER_NATIVE_LIB = "task_vision_jni";
- - private static final int OPTIONAL_FD_LENGTH = -1;
- - private static final int OPTIONAL_FD_OFFSET = -1;
- -
- - /**
- - * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}.
- - *
- - * @param modelPath path of the classification model with metadata in the assets
- - * @throws IOException if an I/O error occurs when loading the tflite model
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ImageClassifier createFromFile(Context context, String modelPath)
- - throws IOException {
- - return createFromFileAndOptions(context, modelPath, ImageClassifierOptions.builder().build());
- - }
- -
- - /**
- - * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}.
- - *
- - * @param modelFile the classification model {@link File} instance
- - * @throws IOException if an I/O error occurs when loading the tflite model
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ImageClassifier createFromFile(File modelFile) throws IOException {
- - return createFromFileAndOptions(modelFile, ImageClassifierOptions.builder().build());
- - }
- -
- - /**
- - * Creates an {@link ImageClassifier} instance with a model buffer and the default {@link
- - * ImageClassifierOptions}.
- - *
- - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- - * classification model
- - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- - * {@link MappedByteBuffer}
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ImageClassifier createFromBuffer(final ByteBuffer modelBuffer) {
- - return createFromBufferAndOptions(modelBuffer, ImageClassifierOptions.builder().build());
- - }
- -
- - /**
- - * Creates an {@link ImageClassifier} instance from {@link ImageClassifierOptions}.
- - *
- - * @param modelPath path of the classification model with metadata in the assets
- - * @throws IOException if an I/O error occurs when loading the tflite model
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ImageClassifier createFromFileAndOptions(
- - Context context, String modelPath, ImageClassifierOptions options) throws IOException {
- - return new ImageClassifier(
- - TaskJniUtils.createHandleFromFdAndOptions(
- - context,
- - new FdAndOptionsHandleProvider<ImageClassifierOptions>() {
- - @Override
- - public long createHandle(
- - int fileDescriptor,
- - long fileDescriptorLength,
- - long fileDescriptorOffset,
- - ImageClassifierOptions options) {
- - return initJniWithModelFdAndOptions(
- - fileDescriptor,
- - fileDescriptorLength,
- - fileDescriptorOffset,
- - options,
- - TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- - options.getBaseOptions(), options.getNumThreads()));
- - }
- - },
- - IMAGE_CLASSIFIER_NATIVE_LIB,
- - modelPath,
- - options));
- - }
- -
- - /**
- - * Creates an {@link ImageClassifier} instance.
- - *
- - * @param modelFile the classification model {@link File} instance
- - * @throws IOException if an I/O error occurs when loading the tflite model
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ImageClassifier createFromFileAndOptions(
- - File modelFile, final ImageClassifierOptions options) throws IOException {
- - try (ParcelFileDescriptor descriptor =
- - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- - return new ImageClassifier(
- - TaskJniUtils.createHandleFromLibrary(
- - new TaskJniUtils.EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - return initJniWithModelFdAndOptions(
- - descriptor.getFd(),
- - /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
- - /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
- - options,
- - TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- - options.getBaseOptions(), options.getNumThreads()));
- - }
- - },
- - IMAGE_CLASSIFIER_NATIVE_LIB));
- + /**
- + * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}.
- + *
- + * @param modelFile the classification model {@link File} instance
- + * @throws IOException if an I/O error occurs when loading the tflite model
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ImageClassifier createFromFile(File modelFile) throws IOException {
- + return createFromFileAndOptions(modelFile, ImageClassifierOptions.builder().build());
- }
- - }
- -
- - /**
- - * Creates an {@link ImageClassifier} instance with a model buffer and {@link
- - * ImageClassifierOptions}.
- - *
- - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- - * classification model
- - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- - * {@link MappedByteBuffer}
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ImageClassifier createFromBufferAndOptions(
- - final ByteBuffer modelBuffer, final ImageClassifierOptions options) {
- - if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- - throw new IllegalArgumentException(
- - "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
- +
- + /**
- + * Creates an {@link ImageClassifier} instance with a model buffer and the default {@link
- + * ImageClassifierOptions}.
- + *
- + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- + * classification model
- + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- + * {@link MappedByteBuffer}
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ImageClassifier createFromBuffer(final ByteBuffer modelBuffer) {
- + return createFromBufferAndOptions(modelBuffer, ImageClassifierOptions.builder().build());
- }
- - return new ImageClassifier(
- - TaskJniUtils.createHandleFromLibrary(
- - new EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - return initJniWithByteBuffer(
- - modelBuffer,
- - options,
- - TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- - options.getBaseOptions(), options.getNumThreads()));
- - }
- - },
- - IMAGE_CLASSIFIER_NATIVE_LIB));
- - }
- -
- - /**
- - * Constructor to initialize the JNI with a pointer from C++.
- - *
- - * @param nativeHandle a pointer referencing memory allocated in C++
- - */
- - ImageClassifier(long nativeHandle) {
- - super(nativeHandle);
- - }
- -
- - /** Options for setting up an ImageClassifier. */
- - @UsedByReflection("image_classifier_jni.cc")
- - public static class ImageClassifierOptions {
- - // Not using AutoValue for this class because scoreThreshold cannot have default value
- - // (otherwise, the default value would override the one in the model metadata) and `Optional` is
- - // not an option here, because
- - // 1. java.util.Optional require Java 8 while we need to support Java 7.
- - // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the
- - // comments for labelAllowList.
- - private final BaseOptions baseOptions;
- - private final String displayNamesLocale;
- - private final int maxResults;
- - private final float scoreThreshold;
- - private final boolean isScoreThresholdSet;
- - // As an open source project, we've been trying avoiding depending on common java libraries,
- - // such as Guava, because it may introduce conflicts with clients who also happen to use those
- - // libraries. Therefore, instead of using ImmutableList here, we convert the List into
- - // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
- - // vulnerable.
- - private final List<String> labelAllowList;
- - private final List<String> labelDenyList;
- - private final int numThreads;
- -
- - public static Builder builder() {
- - return new Builder();
- +
- + /**
- + * Creates an {@link ImageClassifier} instance from {@link ImageClassifierOptions}.
- + *
- + * @param modelPath path of the classification model with metadata in the assets
- + * @throws IOException if an I/O error occurs when loading the tflite model
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ImageClassifier createFromFileAndOptions(
- + Context context, String modelPath, ImageClassifierOptions options) throws IOException {
- + return new ImageClassifier(TaskJniUtils.createHandleFromFdAndOptions(
- + context, new FdAndOptionsHandleProvider<ImageClassifierOptions>() {
- + @Override
- + public long createHandle(int fileDescriptor, long fileDescriptorLength,
- + long fileDescriptorOffset, ImageClassifierOptions options) {
- + return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength,
- + fileDescriptorOffset, options,
- + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- + options.getBaseOptions(), options.getNumThreads()));
- + }
- + }, IMAGE_CLASSIFIER_NATIVE_LIB, modelPath, options));
- }
-
- - /** A builder that helps to configure an instance of ImageClassifierOptions. */
- - public static class Builder {
- - private BaseOptions baseOptions = BaseOptions.builder().build();
- - private String displayNamesLocale = "en";
- - private int maxResults = -1;
- - private float scoreThreshold;
- - private boolean isScoreThresholdSet = false;
- - private List<String> labelAllowList = new ArrayList<>();
- - private List<String> labelDenyList = new ArrayList<>();
- - private int numThreads = -1;
- -
- - Builder() {}
- -
- - /** Sets the general options to configure Task APIs, such as accelerators. */
- - public Builder setBaseOptions(BaseOptions baseOptions) {
- - this.baseOptions = baseOptions;
- - return this;
- - }
- -
- - /**
- - * Sets the locale to use for display names specified through the TFLite Model Metadata, if
- - * any.
- - *
- - * <p>Defaults to English({@code "en"}). See the <a
- - * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
- - * Metadata schema file.</a> for the accepted pattern of locale.
- - */
- - public Builder setDisplayNamesLocale(String displayNamesLocale) {
- - this.displayNamesLocale = displayNamesLocale;
- - return this;
- - }
- -
- - /**
- - * Sets the maximum number of top scored results to return.
- - *
- - * <p>If < 0, all results will be returned. If 0, an invalid argument error is returned.
- - * Defaults to -1.
- - *
- - * @throws IllegalArgumentException if maxResults is 0.
- - */
- - public Builder setMaxResults(int maxResults) {
- - if (maxResults == 0) {
- - throw new IllegalArgumentException("maxResults cannot be 0.");
- + /**
- + * Creates an {@link ImageClassifier} instance.
- + *
- + * @param modelFile the classification model {@link File} instance
- + * @throws IOException if an I/O error occurs when loading the tflite model
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ImageClassifier createFromFileAndOptions(
- + File modelFile, final ImageClassifierOptions options) throws IOException {
- + try (ParcelFileDescriptor descriptor =
- + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- + return new ImageClassifier(
- + TaskJniUtils.createHandleFromLibrary(new TaskJniUtils.EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + return initJniWithModelFdAndOptions(descriptor.getFd(),
- + /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH,
- + /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options,
- + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- + options.getBaseOptions(), options.getNumThreads()));
- + }
- + }, IMAGE_CLASSIFIER_NATIVE_LIB));
- }
- - this.maxResults = maxResults;
- - return this;
- - }
- -
- - /**
- - * Sets the score threshold.
- - *
- - * <p>It overrides the one provided in the model metadata (if any). Results below this value
- - * are rejected.
- - */
- - public Builder setScoreThreshold(float scoreThreshold) {
- - this.scoreThreshold = scoreThreshold;
- - isScoreThresholdSet = true;
- - return this;
- - }
- -
- - /**
- - * Sets the optional allowlist of labels.
- - *
- - * <p>If non-empty, classifications whose label is not in this set will be filtered out.
- - * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList.
- - */
- - public Builder setLabelAllowList(List<String> labelAllowList) {
- - this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
- - return this;
- - }
- -
- - /**
- - * Sets the optional denylist of labels.
- - *
- - * <p>If non-empty, classifications whose label is in this set will be filtered out. Duplicate
- - * or unknown labels are ignored. Mutually exclusive with labelAllowList.
- - */
- - public Builder setLabelDenyList(List<String> labelDenyList) {
- - this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
- - return this;
- - }
- -
- - /**
- - * Sets the number of threads to be used for TFLite ops that support multi-threading when
- - * running inference with CPU. Defaults to -1.
- - *
- - * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
- - * effect to let TFLite runtime set the value.
- - *
- - * @deprecated use {@link BaseOptions} to configure number of threads instead. This method
- - * will override the number of threads configured from {@link BaseOptions}.
- - */
- - @Deprecated
- - public Builder setNumThreads(int numThreads) {
- - this.numThreads = numThreads;
- - return this;
- - }
- -
- - public ImageClassifierOptions build() {
- - return new ImageClassifierOptions(this);
- - }
- }
-
- - @UsedByReflection("image_classifier_jni.cc")
- - public String getDisplayNamesLocale() {
- - return displayNamesLocale;
- + /**
- + * Creates an {@link ImageClassifier} instance with a model buffer and {@link
- + * ImageClassifierOptions}.
- + *
- + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- + * classification model
- + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- + * {@link MappedByteBuffer}
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ImageClassifier createFromBufferAndOptions(
- + final ByteBuffer modelBuffer, final ImageClassifierOptions options) {
- + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- + throw new IllegalArgumentException(
- + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
- + }
- + return new ImageClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + return initJniWithByteBuffer(modelBuffer, options,
- + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- + options.getBaseOptions(), options.getNumThreads()));
- + }
- + }, IMAGE_CLASSIFIER_NATIVE_LIB));
- }
-
- - @UsedByReflection("image_classifier_jni.cc")
- - public int getMaxResults() {
- - return maxResults;
- + /**
- + * Constructor to initialize the JNI with a pointer from C++.
- + *
- + * @param nativeHandle a pointer referencing memory allocated in C++
- + */
- + ImageClassifier(long nativeHandle) {
- + super(nativeHandle);
- }
-
- + /** Options for setting up an ImageClassifier. */
- @UsedByReflection("image_classifier_jni.cc")
- - public float getScoreThreshold() {
- - return scoreThreshold;
- + public static class ImageClassifierOptions {
- + // Not using AutoValue for this class because scoreThreshold cannot have default value
- + // (otherwise, the default value would override the one in the model metadata) and
- + // `Optional` is not an option here, because
- + // 1. java.util.Optional require Java 8 while we need to support Java 7.
- + // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See
- + // the comments for labelAllowList.
- + private final BaseOptions baseOptions;
- + private final String displayNamesLocale;
- + private final int maxResults;
- + private final float scoreThreshold;
- + private final boolean isScoreThresholdSet;
- + // As an open source project, we've been trying avoiding depending on common java libraries,
- + // such as Guava, because it may introduce conflicts with clients who also happen to use
- + // those libraries. Therefore, instead of using ImmutableList here, we convert the List into
- + // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
- + // vulnerable.
- + private final List<String> labelAllowList;
- + private final List<String> labelDenyList;
- + private final int numThreads;
- +
- + public static Builder builder() {
- + return new Builder();
- + }
- +
- + /** A builder that helps to configure an instance of ImageClassifierOptions. */
- + public static class Builder {
- + private BaseOptions baseOptions = BaseOptions.builder().build();
- + private String displayNamesLocale = "en";
- + private int maxResults = -1;
- + private float scoreThreshold;
- + private boolean isScoreThresholdSet = false;
- + private List<String> labelAllowList = new ArrayList<>();
- + private List<String> labelDenyList = new ArrayList<>();
- + private int numThreads = -1;
- +
- + Builder() {}
- +
- + /** Sets the general options to configure Task APIs, such as accelerators. */
- + public Builder setBaseOptions(BaseOptions baseOptions) {
- + this.baseOptions = baseOptions;
- + return this;
- + }
- +
- + /**
- + * Sets the locale to use for display names specified through the TFLite Model Metadata,
- + * if any.
- + *
- + * <p>Defaults to English({@code "en"}). See the <a
- + * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
- + * Metadata schema file.</a> for the accepted pattern of locale.
- + */
- + public Builder setDisplayNamesLocale(String displayNamesLocale) {
- + this.displayNamesLocale = displayNamesLocale;
- + return this;
- + }
- +
- + /**
- + * Sets the maximum number of top scored results to return.
- + *
- + * <p>If < 0, all results will be returned. If 0, an invalid argument error is returned.
- + * Defaults to -1.
- + *
- + * @throws IllegalArgumentException if maxResults is 0.
- + */
- + public Builder setMaxResults(int maxResults) {
- + if (maxResults == 0) {
- + throw new IllegalArgumentException("maxResults cannot be 0.");
- + }
- + this.maxResults = maxResults;
- + return this;
- + }
- +
- + /**
- + * Sets the score threshold.
- + *
- + * <p>It overrides the one provided in the model metadata (if any). Results below this
- + * value are rejected.
- + */
- + public Builder setScoreThreshold(float scoreThreshold) {
- + this.scoreThreshold = scoreThreshold;
- + isScoreThresholdSet = true;
- + return this;
- + }
- +
- + /**
- + * Sets the optional allowlist of labels.
- + *
- + * <p>If non-empty, classifications whose label is not in this set will be filtered out.
- + * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList.
- + */
- + public Builder setLabelAllowList(List<String> labelAllowList) {
- + this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
- + return this;
- + }
- +
- + /**
- + * Sets the optional denylist of labels.
- + *
- + * <p>If non-empty, classifications whose label is in this set will be filtered out.
- + * Duplicate or unknown labels are ignored. Mutually exclusive with labelAllowList.
- + */
- + public Builder setLabelDenyList(List<String> labelDenyList) {
- + this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
- + return this;
- + }
- +
- + /**
- + * Sets the number of threads to be used for TFLite ops that support multi-threading
- + * when running inference with CPU. Defaults to -1.
- + *
- + * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has
- + * the effect to let TFLite runtime set the value.
- + *
- + * @deprecated use {@link BaseOptions} to configure number of threads instead. This
- + * method
- + * will override the number of threads configured from {@link BaseOptions}.
- + */
- + @Deprecated
- + public Builder setNumThreads(int numThreads) {
- + this.numThreads = numThreads;
- + return this;
- + }
- +
- + public ImageClassifierOptions build() {
- + return new ImageClassifierOptions(this);
- + }
- + }
- +
- + @UsedByReflection("image_classifier_jni.cc")
- + public String getDisplayNamesLocale() {
- + return displayNamesLocale;
- + }
- +
- + @UsedByReflection("image_classifier_jni.cc")
- + public int getMaxResults() {
- + return maxResults;
- + }
- +
- + @UsedByReflection("image_classifier_jni.cc")
- + public float getScoreThreshold() {
- + return scoreThreshold;
- + }
- +
- + @UsedByReflection("image_classifier_jni.cc")
- + public boolean getIsScoreThresholdSet() {
- + return isScoreThresholdSet;
- + }
- +
- + @UsedByReflection("image_classifier_jni.cc")
- + public List<String> getLabelAllowList() {
- + return new ArrayList<>(labelAllowList);
- + }
- +
- + @UsedByReflection("image_classifier_jni.cc")
- + public List<String> getLabelDenyList() {
- + return new ArrayList<>(labelDenyList);
- + }
- +
- + @UsedByReflection("image_classifier_jni.cc")
- + public int getNumThreads() {
- + return numThreads;
- + }
- +
- + public BaseOptions getBaseOptions() {
- + return baseOptions;
- + }
- +
- + ImageClassifierOptions(Builder builder) {
- + displayNamesLocale = builder.displayNamesLocale;
- + maxResults = builder.maxResults;
- + scoreThreshold = builder.scoreThreshold;
- + isScoreThresholdSet = builder.isScoreThresholdSet;
- + labelAllowList = builder.labelAllowList;
- + labelDenyList = builder.labelDenyList;
- + numThreads = builder.numThreads;
- + baseOptions = builder.baseOptions;
- + }
- }
-
- - @UsedByReflection("image_classifier_jni.cc")
- - public boolean getIsScoreThresholdSet() {
- - return isScoreThresholdSet;
- + /**
- + * Performs actual classification on the provided {@link TensorImage}.
- + *
- + * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types:
- + *
- + * <ul>
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- + * </ul>
- + *
- + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- + * @throws IllegalArgumentException if the color space type of image is unsupported
- + */
- + public List<Classifications> classify(TensorImage image) {
- + return classify(image, ImageProcessingOptions.builder().build());
- }
-
- - @UsedByReflection("image_classifier_jni.cc")
- - public List<String> getLabelAllowList() {
- - return new ArrayList<>(labelAllowList);
- + /**
- + * Performs actual classification on the provided {@link TensorImage} with {@link
- + * ImageProcessingOptions}.
- + *
- + * <p>{@link ImageClassifier} supports the following options:
- + *
- + * <ul>
- + * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
- + * defaults to the entire image.
- + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
- + * </ul>
- + *
- + * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types:
- + *
- + * <ul>
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- + * </ul>
- + *
- + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- + * @throws IllegalArgumentException if the color space type of image is unsupported
- + */
- + public List<Classifications> classify(TensorImage image, ImageProcessingOptions options) {
- + return run(new InferenceProvider<List<Classifications>>() {
- + @Override
- + public List<Classifications> run(
- + long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
- + return classify(frameBufferHandle, width, height, options);
- + }
- + }, image, options);
- }
-
- - @UsedByReflection("image_classifier_jni.cc")
- - public List<String> getLabelDenyList() {
- - return new ArrayList<>(labelDenyList);
- + /**
- + * Performs actual classification on the provided {@code MlImage}.
- + *
- + * @param image an {@code MlImage} object that represents an image
- + * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- + */
- + public List<Classifications> classify(MlImage image) {
- + return classify(image, ImageProcessingOptions.builder().build());
- }
-
- - @UsedByReflection("image_classifier_jni.cc")
- - public int getNumThreads() {
- - return numThreads;
- + /**
- + * Performs actual classification on the provided {@code MlImage} with {@link
- + * ImageProcessingOptions}.
- + *
- + * <p>{@link ImageClassifier} supports the following options:
- + *
- + * <ul>
- + * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
- + * defaults to the entire image.
- + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
- + * MlImage#getRotation()} is not effective.
- + * </ul>
- + *
- + * @param image a {@code MlImage} object that represents an image
- + * @param options configures options including ROI and rotation
- + * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- + */
- + public List<Classifications> classify(MlImage image, ImageProcessingOptions options) {
- + image.getInternal().acquire();
- + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- + List<Classifications> result = classify(tensorImage, options);
- + image.close();
- + return result;
- }
-
- - public BaseOptions getBaseOptions() {
- - return baseOptions;
- + private List<Classifications> classify(
- + long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
- + checkNotClosed();
- +
- + Rect roi = options.getRoi().isEmpty() ? new Rect(0, 0, width, height) : options.getRoi();
- +
- + return classifyNative(getNativeHandle(), frameBufferHandle,
- + new int[] {roi.left, roi.top, roi.width(), roi.height()});
- }
-
- - ImageClassifierOptions(Builder builder) {
- - displayNamesLocale = builder.displayNamesLocale;
- - maxResults = builder.maxResults;
- - scoreThreshold = builder.scoreThreshold;
- - isScoreThresholdSet = builder.isScoreThresholdSet;
- - labelAllowList = builder.labelAllowList;
- - labelDenyList = builder.labelDenyList;
- - numThreads = builder.numThreads;
- - baseOptions = builder.baseOptions;
- + private static native long initJniWithModelFdAndOptions(int fileDescriptor,
- + long fileDescriptorLength, long fileDescriptorOffset, ImageClassifierOptions options,
- + long baseOptionsHandle);
- +
- + private static native long initJniWithByteBuffer(
- + ByteBuffer modelBuffer, ImageClassifierOptions options, long baseOptionsHandle);
- +
- + /**
- + * The native method to classify an image with the ROI and orientation.
- + *
- + * @param roi the ROI of the input image, an array representing the bounding box as {left, top,
- + * width, height}
- + */
- + private static native List<Classifications> classifyNative(
- + long nativeHandle, long frameBufferHandle, int[] roi);
- +
- + @Override
- + protected void deinit(long nativeHandle) {
- + deinitJni(nativeHandle);
- }
- - }
- -
- - /**
- - * Performs actual classification on the provided {@link TensorImage}.
- - *
- - * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types:
- - *
- - * <ul>
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- - * </ul>
- - *
- - * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- - * @throws IllegalArgumentException if the color space type of image is unsupported
- - */
- - public List<Classifications> classify(TensorImage image) {
- - return classify(image, ImageProcessingOptions.builder().build());
- - }
- -
- - /**
- - * Performs actual classification on the provided {@link TensorImage} with {@link
- - * ImageProcessingOptions}.
- - *
- - * <p>{@link ImageClassifier} supports the following options:
- - *
- - * <ul>
- - * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
- - * defaults to the entire image.
- - * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- - * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
- - * </ul>
- - *
- - * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types:
- - *
- - * <ul>
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- - * </ul>
- - *
- - * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- - * @throws IllegalArgumentException if the color space type of image is unsupported
- - */
- - public List<Classifications> classify(TensorImage image, ImageProcessingOptions options) {
- - return run(
- - new InferenceProvider<List<Classifications>>() {
- - @Override
- - public List<Classifications> run(
- - long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
- - return classify(frameBufferHandle, width, height, options);
- - }
- - },
- - image,
- - options);
- - }
- -
- - /**
- - * Performs actual classification on the provided {@code MlImage}.
- - *
- - * @param image an {@code MlImage} object that represents an image
- - * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- - */
- - public List<Classifications> classify(MlImage image) {
- - return classify(image, ImageProcessingOptions.builder().build());
- - }
- -
- - /**
- - * Performs actual classification on the provided {@code MlImage} with {@link
- - * ImageProcessingOptions}.
- - *
- - * <p>{@link ImageClassifier} supports the following options:
- - *
- - * <ul>
- - * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
- - * defaults to the entire image.
- - * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- - * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
- - * MlImage#getRotation()} is not effective.
- - * </ul>
- - *
- - * @param image a {@code MlImage} object that represents an image
- - * @param options configures options including ROI and rotation
- - * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- - */
- - public List<Classifications> classify(MlImage image, ImageProcessingOptions options) {
- - image.getInternal().acquire();
- - TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- - List<Classifications> result = classify(tensorImage, options);
- - image.close();
- - return result;
- - }
- -
- - private List<Classifications> classify(
- - long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
- - checkNotClosed();
- -
- - Rect roi = options.getRoi().isEmpty() ? new Rect(0, 0, width, height) : options.getRoi();
- -
- - return classifyNative(
- - getNativeHandle(),
- - frameBufferHandle,
- - new int[] {roi.left, roi.top, roi.width(), roi.height()});
- - }
- -
- - private static native long initJniWithModelFdAndOptions(
- - int fileDescriptor,
- - long fileDescriptorLength,
- - long fileDescriptorOffset,
- - ImageClassifierOptions options,
- - long baseOptionsHandle);
- -
- - private static native long initJniWithByteBuffer(
- - ByteBuffer modelBuffer, ImageClassifierOptions options, long baseOptionsHandle);
- -
- - /**
- - * The native method to classify an image with the ROI and orientation.
- - *
- - * @param roi the ROI of the input image, an array representing the bounding box as {left, top,
- - * width, height}
- - */
- - private static native List<Classifications> classifyNative(
- - long nativeHandle, long frameBufferHandle, int[] roi);
- -
- - @Override
- - protected void deinit(long nativeHandle) {
- - deinitJni(nativeHandle);
- - }
- -
- - /**
- - * Native implementation to release memory pointed by the pointer.
- - *
- - * @param nativeHandle pointer to memory allocated
- - */
- - private native void deinitJni(long nativeHandle);
- +
- + /**
- + * Native implementation to release memory pointed by the pointer.
- + *
- + * @param nativeHandle pointer to memory allocated
- + */
- + private native void deinitJni(long nativeHandle);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java
- index fdc898f451337..59ab62a949a25 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java
- @@ -21,213 +21,184 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
- import android.graphics.ImageFormat;
- import android.media.Image;
- import android.media.Image.Plane;
- +
- import com.google.auto.value.AutoValue;
- -import java.nio.ByteBuffer;
- +
- import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.image.ColorSpaceType;
- import org.tensorflow.lite.support.image.TensorImage;
- import org.tensorflow.lite.task.core.BaseTaskApi;
- import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
-
- +import java.nio.ByteBuffer;
- +
- /** Base class for Task Vision APIs. */
- public abstract class BaseVisionTaskApi extends BaseTaskApi {
- -
- - /** Syntax sugar to run vision tasks with FrameBuffer and image processing options. */
- - public interface InferenceProvider<T> {
- - T run(long frameBufferHandle, int width, int height, ImageProcessingOptions options);
- - }
- -
- - protected BaseVisionTaskApi(long nativeHandle) {
- - super(nativeHandle);
- - }
- -
- - /** Runs inference with {@link TensorImage} and {@link ImageProcessingOptions}. */
- - protected <T> T run(
- - InferenceProvider<T> provider, TensorImage image, ImageProcessingOptions options) {
- - FrameBufferData frameBufferData = createFrameBuffer(image, options.getOrientation().getValue());
- - T results =
- - provider.run(
- - frameBufferData.getFrameBufferHandle(), image.getWidth(), image.getHeight(), options);
- - deleteFrameBuffer(
- - frameBufferData.getFrameBufferHandle(),
- - frameBufferData.getByteArrayHandle(),
- - frameBufferData.getByteArray());
- - return results;
- - }
- -
- - private static FrameBufferData createFrameBuffer(TensorImage image, int orientation) {
- - ColorSpaceType colorSpaceType = image.getColorSpaceType();
- - switch (colorSpaceType) {
- - case RGB:
- - case NV12:
- - case NV21:
- - case YV12:
- - case YV21:
- - // All these types can be converted to ByteBuffer inside TensorImage. Creating FrameBuffer
- - // base on the image ByteBuffer.
- - return createFrameBufferFromByteBuffer(image, orientation);
- - case YUV_420_888:
- - // YUV_420_888 is a specific type for android.media.Image.
- - return createFrameBufferFromMediaImage(image, orientation);
- - default:
- - throw new IllegalArgumentException(
- - "Color space type, " + colorSpaceType.name() + ", is unsupported.");
- + /** Syntax sugar to run vision tasks with FrameBuffer and image processing options. */
- + public interface InferenceProvider<T> {
- + T run(long frameBufferHandle, int width, int height, ImageProcessingOptions options);
- }
- - }
- -
- - /**
- - * Creates FrameBuffer from the {@link android.media.Image} stored in the given {@link
- - * TensorImage}.
- - */
- - private static FrameBufferData createFrameBufferFromMediaImage(
- - TensorImage image, int orientation) {
- - Image mediaImage = image.getMediaImage();
- -
- - checkArgument(
- - mediaImage.getFormat() == ImageFormat.YUV_420_888,
- - "Only supports loading YUV_420_888 Image.");
- -
- - Plane[] planes = mediaImage.getPlanes();
- - checkArgument(
- - planes.length == 3,
- - String.format("The input image should have 3 planes, but got %d plane(s).", planes.length));
- -
- - // Verify and rewind planes.
- - for (Plane plane : planes) {
- - ByteBuffer buffer = plane.getBuffer();
- - checkNotNull(buffer, "The image buffer is corrupted and the plane is null.");
- - // From the public documentation, plane.getBuffer() should always return a direct ByteBuffer.
- - // See https://developer.android.com/reference/android/media/Image.Plane#getBuffer()
- - checkArgument(
- - buffer.isDirect(),
- - "The image plane buffer is not a direct ByteBuffer, and is not supported.");
- - buffer.rewind();
- +
- + protected BaseVisionTaskApi(long nativeHandle) {
- + super(nativeHandle);
- }
-
- - return FrameBufferData.create(
- - createFrameBufferFromPlanes(
- - planes[0].getBuffer(),
- - planes[1].getBuffer(),
- - planes[2].getBuffer(),
- - mediaImage.getWidth(),
- - mediaImage.getHeight(),
- - planes[0].getRowStride(),
- - // row_stride and pixel_stride should be identical for U/V planes.
- - planes[1].getRowStride(),
- - planes[1].getPixelStride(),
- - orientation),
- - // FrameBuffer created with direct ByteBuffer does not require memory freeing.
- - /*byteArrayHandle=*/ 0,
- - /*byteArray=*/ new byte[0]);
- - }
- -
- - /** Creates FrameBuffer from the {@link ByteBuffer} stored in the given {@link TensorImage}. */
- - private static FrameBufferData createFrameBufferFromByteBuffer(
- - TensorImage image, int orientation) {
- - // base_vision_api_jni.cc expects an uint8 image. Convert image of other types into uint8.
- - TensorImage imageUint8 =
- - image.getDataType() == DataType.UINT8
- - ? image
- - : TensorImage.createFrom(image, DataType.UINT8);
- -
- - ByteBuffer byteBuffer = imageUint8.getBuffer();
- - byteBuffer.rewind();
- - ColorSpaceType colorSpaceType = image.getColorSpaceType();
- - if (byteBuffer.isDirect()) {
- - return FrameBufferData.create(
- - createFrameBufferFromByteBuffer(
- - byteBuffer,
- - imageUint8.getWidth(),
- - imageUint8.getHeight(),
- - orientation,
- - colorSpaceType.getValue()),
- - // FrameBuffer created with direct ByteBuffer does not require memory freeing.
- - /*byteArrayHandle=*/ 0,
- - /*byteArray=*/ new byte[0]);
- - } else {
- - // If the byte array is copied in jni (during GetByteArrayElements), need to free
- - // the copied array once inference is done.
- - long[] byteArrayHandle = new long[1];
- - byte[] byteArray = getBytesFromByteBuffer(byteBuffer);
- - return FrameBufferData.create(
- - createFrameBufferFromBytes(
- - byteArray,
- - imageUint8.getWidth(),
- - imageUint8.getHeight(),
- - orientation,
- - colorSpaceType.getValue(),
- - byteArrayHandle),
- - byteArrayHandle[0],
- - byteArray);
- + /** Runs inference with {@link TensorImage} and {@link ImageProcessingOptions}. */
- + protected <T> T run(
- + InferenceProvider<T> provider, TensorImage image, ImageProcessingOptions options) {
- + FrameBufferData frameBufferData =
- + createFrameBuffer(image, options.getOrientation().getValue());
- + T results = provider.run(frameBufferData.getFrameBufferHandle(), image.getWidth(),
- + image.getHeight(), options);
- + deleteFrameBuffer(frameBufferData.getFrameBufferHandle(),
- + frameBufferData.getByteArrayHandle(), frameBufferData.getByteArray());
- + return results;
- }
- - }
-
- - /** Holds the FrameBuffer and the underlying data pointers in C++. */
- - @AutoValue
- - abstract static class FrameBufferData {
- + private static FrameBufferData createFrameBuffer(TensorImage image, int orientation) {
- + ColorSpaceType colorSpaceType = image.getColorSpaceType();
- + switch (colorSpaceType) {
- + case RGB:
- + case NV12:
- + case NV21:
- + case YV12:
- + case YV21:
- + // All these types can be converted to ByteBuffer inside TensorImage. Creating
- + // FrameBuffer base on the image ByteBuffer.
- + return createFrameBufferFromByteBuffer(image, orientation);
- + case YUV_420_888:
- + // YUV_420_888 is a specific type for android.media.Image.
- + return createFrameBufferFromMediaImage(image, orientation);
- + default:
- + throw new IllegalArgumentException(
- + "Color space type, " + colorSpaceType.name() + ", is unsupported.");
- + }
- + }
-
- /**
- - * Initializes a {@link FrameBufferData} object.
- - *
- - * @param frameBufferHandle the native handle to the FrameBuffer object.
- - * @param byteArrayHandle the native handle to the data array that backs up the FrameBuffer
- - * object. If the FrameBuffer is created on a byte array, this byte array need to be freed
- - * after inference is done. If the FrameBuffer is created on a direct ByteBuffer, no byte
- - * array needs to be freed, and byteArrayHandle will be 0.
- - * @param byteArray the byte array that is used to create the c++ byte array object, which is
- - * needed when releasing byteArrayHandle. If the FrameBuffer is created on a direct
- - * ByteBuffer (no byte array needs to be freed), pass in an empty array for {@code
- - * byteArray}.
- + * Creates FrameBuffer from the {@link android.media.Image} stored in the given {@link
- + * TensorImage}.
- */
- - public static FrameBufferData create(
- - long frameBufferHandle, long byteArrayHandle, byte[] byteArray) {
- - return new AutoValue_BaseVisionTaskApi_FrameBufferData(
- - frameBufferHandle, byteArrayHandle, byteArray);
- + private static FrameBufferData createFrameBufferFromMediaImage(
- + TensorImage image, int orientation) {
- + Image mediaImage = image.getMediaImage();
- +
- + checkArgument(mediaImage.getFormat() == ImageFormat.YUV_420_888,
- + "Only supports loading YUV_420_888 Image.");
- +
- + Plane[] planes = mediaImage.getPlanes();
- + checkArgument(planes.length == 3,
- + String.format("The input image should have 3 planes, but got %d plane(s).",
- + planes.length));
- +
- + // Verify and rewind planes.
- + for (Plane plane : planes) {
- + ByteBuffer buffer = plane.getBuffer();
- + checkNotNull(buffer, "The image buffer is corrupted and the plane is null.");
- + // From the public documentation, plane.getBuffer() should always return a direct
- + // ByteBuffer. See
- + // https://developer.android.com/reference/android/media/Image.Plane#getBuffer()
- + checkArgument(buffer.isDirect(),
- + "The image plane buffer is not a direct ByteBuffer, and is not supported.");
- + buffer.rewind();
- + }
- +
- + return FrameBufferData.create(
- + createFrameBufferFromPlanes(planes[0].getBuffer(), planes[1].getBuffer(),
- + planes[2].getBuffer(), mediaImage.getWidth(), mediaImage.getHeight(),
- + planes[0].getRowStride(),
- + // row_stride and pixel_stride should be identical for U/V planes.
- + planes[1].getRowStride(), planes[1].getPixelStride(), orientation),
- + // FrameBuffer created with direct ByteBuffer does not require memory freeing.
- + /*byteArrayHandle=*/0,
- + /*byteArray=*/new byte[0]);
- + }
- +
- + /** Creates FrameBuffer from the {@link ByteBuffer} stored in the given {@link TensorImage}. */
- + private static FrameBufferData createFrameBufferFromByteBuffer(
- + TensorImage image, int orientation) {
- + // base_vision_api_jni.cc expects an uint8 image. Convert image of other types into uint8.
- + TensorImage imageUint8 = image.getDataType() == DataType.UINT8
- + ? image
- + : TensorImage.createFrom(image, DataType.UINT8);
- +
- + ByteBuffer byteBuffer = imageUint8.getBuffer();
- + byteBuffer.rewind();
- + ColorSpaceType colorSpaceType = image.getColorSpaceType();
- + if (byteBuffer.isDirect()) {
- + return FrameBufferData.create(
- + createFrameBufferFromByteBuffer(byteBuffer, imageUint8.getWidth(),
- + imageUint8.getHeight(), orientation, colorSpaceType.getValue()),
- + // FrameBuffer created with direct ByteBuffer does not require memory freeing.
- + /*byteArrayHandle=*/0,
- + /*byteArray=*/new byte[0]);
- + } else {
- + // If the byte array is copied in jni (during GetByteArrayElements), need to free
- + // the copied array once inference is done.
- + long[] byteArrayHandle = new long[1];
- + byte[] byteArray = getBytesFromByteBuffer(byteBuffer);
- + return FrameBufferData.create(
- + createFrameBufferFromBytes(byteArray, imageUint8.getWidth(),
- + imageUint8.getHeight(), orientation, colorSpaceType.getValue(),
- + byteArrayHandle),
- + byteArrayHandle[0], byteArray);
- + }
- + }
- +
- + /** Holds the FrameBuffer and the underlying data pointers in C++. */
- + @AutoValue
- + abstract static class FrameBufferData {
- + /**
- + * Initializes a {@link FrameBufferData} object.
- + *
- + * @param frameBufferHandle the native handle to the FrameBuffer object.
- + * @param byteArrayHandle the native handle to the data array that backs up the FrameBuffer
- + * object. If the FrameBuffer is created on a byte array, this byte array need to be
- + * freed after inference is done. If the FrameBuffer is created on a direct ByteBuffer, no
- + * byte array needs to be freed, and byteArrayHandle will be 0.
- + * @param byteArray the byte array that is used to create the c++ byte array object, which
- + * is
- + * needed when releasing byteArrayHandle. If the FrameBuffer is created on a direct
- + * ByteBuffer (no byte array needs to be freed), pass in an empty array for {@code
- + * byteArray}.
- + */
- + public static FrameBufferData create(
- + long frameBufferHandle, long byteArrayHandle, byte[] byteArray) {
- + return new AutoValue_BaseVisionTaskApi_FrameBufferData(
- + frameBufferHandle, byteArrayHandle, byteArray);
- + }
- +
- + abstract long getFrameBufferHandle();
- +
- + abstract long getByteArrayHandle();
- +
- + // Package private method for transferring data.
- + @SuppressWarnings("mutable")
- + abstract byte[] getByteArray();
- }
-
- - abstract long getFrameBufferHandle();
- -
- - abstract long getByteArrayHandle();
- -
- - // Package private method for transferring data.
- - @SuppressWarnings("mutable")
- - abstract byte[] getByteArray();
- - }
- -
- - private static native long createFrameBufferFromByteBuffer(
- - ByteBuffer image, int width, int height, int orientation, int colorSpaceType);
- -
- - private static native long createFrameBufferFromBytes(
- - byte[] image,
- - int width,
- - int height,
- - int orientation,
- - int colorSpaceType,
- - long[] byteArrayHandle);
- -
- - private static native long createFrameBufferFromPlanes(
- - ByteBuffer yBuffer,
- - ByteBuffer uBuffer,
- - ByteBuffer vBuffer,
- - int width,
- - int height,
- - int yRowStride,
- - int uvRowStride,
- - int uvPixelStride,
- - int orientation);
- -
- - private static native void deleteFrameBuffer(
- - long frameBufferHandle, long byteArrayHandle, byte[] byteArray);
- -
- - private static byte[] getBytesFromByteBuffer(ByteBuffer byteBuffer) {
- - // If the ByteBuffer has a back up array, use it directly without copy.
- - if (byteBuffer.hasArray() && byteBuffer.arrayOffset() == 0) {
- - return byteBuffer.array();
- + private static native long createFrameBufferFromByteBuffer(
- + ByteBuffer image, int width, int height, int orientation, int colorSpaceType);
- +
- + private static native long createFrameBufferFromBytes(byte[] image, int width, int height,
- + int orientation, int colorSpaceType, long[] byteArrayHandle);
- +
- + private static native long createFrameBufferFromPlanes(ByteBuffer yBuffer, ByteBuffer uBuffer,
- + ByteBuffer vBuffer, int width, int height, int yRowStride, int uvRowStride,
- + int uvPixelStride, int orientation);
- +
- + private static native void deleteFrameBuffer(
- + long frameBufferHandle, long byteArrayHandle, byte[] byteArray);
- +
- + private static byte[] getBytesFromByteBuffer(ByteBuffer byteBuffer) {
- + // If the ByteBuffer has a back up array, use it directly without copy.
- + if (byteBuffer.hasArray() && byteBuffer.arrayOffset() == 0) {
- + return byteBuffer.array();
- + }
- + // Copy out the data otherwise.
- + byteBuffer.rewind();
- + byte[] bytes = new byte[byteBuffer.limit()];
- + byteBuffer.get(bytes, 0, bytes.length);
- + return bytes;
- }
- - // Copy out the data otherwise.
- - byteBuffer.rewind();
- - byte[] bytes = new byte[byteBuffer.limit()];
- - byteBuffer.get(bytes, 0, bytes.length);
- - return bytes;
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java
- index 859e41fc038be..096af521c6b00 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java
- @@ -16,27 +16,29 @@ limitations under the License.
- package org.tensorflow.lite.task.vision.detector;
-
- import android.graphics.RectF;
- +
- import com.google.auto.value.AutoValue;
- +
- +import org.tensorflow.lite.support.label.Category;
- +import org.tensorflow.lite.task.core.annotations.UsedByReflection;
- +
- import java.util.ArrayList;
- import java.util.Collections;
- import java.util.List;
- -import org.tensorflow.lite.support.label.Category;
- -import org.tensorflow.lite.task.core.annotations.UsedByReflection;
-
- /** Represents one detected object in the results of a {@link ObjectDetector}. */
- @AutoValue
- @UsedByReflection("object_detection_jni.cc")
- public abstract class Detection {
- + @UsedByReflection("object_detection_jni.cc")
- + public static Detection create(RectF boundingBox, List<Category> categories) {
- + return new AutoValue_Detection(new RectF(boundingBox),
- + Collections.unmodifiableList(new ArrayList<Category>(categories)));
- + }
-
- - @UsedByReflection("object_detection_jni.cc")
- - public static Detection create(RectF boundingBox, List<Category> categories) {
- - return new AutoValue_Detection(
- - new RectF(boundingBox), Collections.unmodifiableList(new ArrayList<Category>(categories)));
- - }
- -
- - public abstract RectF getBoundingBox();
- + public abstract RectF getBoundingBox();
-
- - // Same reason for not using ImmutableList as stated in
- - // {@link ObjectDetector#ObjectDetectorOptions#labelAllowList}.
- - public abstract List<Category> getCategories();
- + // Same reason for not using ImmutableList as stated in
- + // {@link ObjectDetector#ObjectDetectorOptions#labelAllowList}.
- + public abstract List<Category> getCategories();
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java
- index 4aff7bfab8ca5..d1fb421fc0bbf 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java
- @@ -17,14 +17,9 @@ package org.tensorflow.lite.task.vision.detector;
-
- import android.content.Context;
- import android.os.ParcelFileDescriptor;
- +
- import com.google.android.odml.image.MlImage;
- -import java.io.File;
- -import java.io.IOException;
- -import java.nio.ByteBuffer;
- -import java.nio.MappedByteBuffer;
- -import java.util.ArrayList;
- -import java.util.Collections;
- -import java.util.List;
- +
- import org.tensorflow.lite.support.image.MlImageAdapter;
- import org.tensorflow.lite.support.image.TensorImage;
- import org.tensorflow.lite.task.core.BaseOptions;
- @@ -35,6 +30,14 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection;
- import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
- import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
-
- +import java.io.File;
- +import java.io.IOException;
- +import java.nio.ByteBuffer;
- +import java.nio.MappedByteBuffer;
- +import java.util.ArrayList;
- +import java.util.Collections;
- +import java.util.List;
- +
- /**
- * Performs object detection on images.
- *
- @@ -86,469 +89,447 @@ import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
- * Hub.</a>.
- */
- public final class ObjectDetector extends BaseVisionTaskApi {
- + private static final String OBJECT_DETECTOR_NATIVE_LIB = "task_vision_jni";
- + private static final int OPTIONAL_FD_LENGTH = -1;
- + private static final int OPTIONAL_FD_OFFSET = -1;
- +
- + /**
- + * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}.
- + *
- + * @param modelPath path to the detection model with metadata in the assets
- + * @throws IOException if an I/O error occurs when loading the tflite model
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ObjectDetector createFromFile(Context context, String modelPath)
- + throws IOException {
- + return createFromFileAndOptions(
- + context, modelPath, ObjectDetectorOptions.builder().build());
- + }
-
- - private static final String OBJECT_DETECTOR_NATIVE_LIB = "task_vision_jni";
- - private static final int OPTIONAL_FD_LENGTH = -1;
- - private static final int OPTIONAL_FD_OFFSET = -1;
- -
- - /**
- - * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}.
- - *
- - * @param modelPath path to the detection model with metadata in the assets
- - * @throws IOException if an I/O error occurs when loading the tflite model
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ObjectDetector createFromFile(Context context, String modelPath)
- - throws IOException {
- - return createFromFileAndOptions(context, modelPath, ObjectDetectorOptions.builder().build());
- - }
- -
- - /**
- - * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}.
- - *
- - * @param modelFile the detection model {@link File} instance
- - * @throws IOException if an I/O error occurs when loading the tflite model
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ObjectDetector createFromFile(File modelFile) throws IOException {
- - return createFromFileAndOptions(modelFile, ObjectDetectorOptions.builder().build());
- - }
- -
- - /**
- - * Creates an {@link ObjectDetector} instance with a model buffer and the default {@link
- - * ObjectDetectorOptions}.
- - *
- - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
- - * model
- - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- - * {@link MappedByteBuffer} * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ObjectDetector createFromBuffer(final ByteBuffer modelBuffer) {
- - return createFromBufferAndOptions(modelBuffer, ObjectDetectorOptions.builder().build());
- - }
- -
- - /**
- - * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}.
- - *
- - * @param modelPath path to the detection model with metadata in the assets
- - * @throws IOException if an I/O error occurs when loading the tflite model
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ObjectDetector createFromFileAndOptions(
- - Context context, String modelPath, ObjectDetectorOptions options) throws IOException {
- - return new ObjectDetector(
- - TaskJniUtils.createHandleFromFdAndOptions(
- - context,
- - new FdAndOptionsHandleProvider<ObjectDetectorOptions>() {
- - @Override
- - public long createHandle(
- - int fileDescriptor,
- - long fileDescriptorLength,
- - long fileDescriptorOffset,
- - ObjectDetectorOptions options) {
- - return initJniWithModelFdAndOptions(
- - fileDescriptor,
- - fileDescriptorLength,
- - fileDescriptorOffset,
- - options,
- - TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- - options.getBaseOptions(), options.getNumThreads()));
- - }
- - },
- - OBJECT_DETECTOR_NATIVE_LIB,
- - modelPath,
- - options));
- - }
- -
- - /**
- - * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}.
- - *
- - * @param modelFile the detection model {@link File} instance
- - * @throws IOException if an I/O error occurs when loading the tflite model
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ObjectDetector createFromFileAndOptions(
- - File modelFile, final ObjectDetectorOptions options) throws IOException {
- - try (ParcelFileDescriptor descriptor =
- - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- - return new ObjectDetector(
- - TaskJniUtils.createHandleFromLibrary(
- - new TaskJniUtils.EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - return initJniWithModelFdAndOptions(
- - descriptor.getFd(),
- - /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
- - /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
- - options,
- - TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- - options.getBaseOptions(), options.getNumThreads()));
- - }
- - },
- - OBJECT_DETECTOR_NATIVE_LIB));
- + /**
- + * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}.
- + *
- + * @param modelFile the detection model {@link File} instance
- + * @throws IOException if an I/O error occurs when loading the tflite model
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ObjectDetector createFromFile(File modelFile) throws IOException {
- + return createFromFileAndOptions(modelFile, ObjectDetectorOptions.builder().build());
- }
- - }
- -
- - /**
- - * Creates an {@link ObjectDetector} instance with a model buffer and {@link
- - * ObjectDetectorOptions}.
- - *
- - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
- - * model
- - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- - * {@link MappedByteBuffer}
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ObjectDetector createFromBufferAndOptions(
- - final ByteBuffer modelBuffer, final ObjectDetectorOptions options) {
- - if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- - throw new IllegalArgumentException(
- - "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
- +
- + /**
- + * Creates an {@link ObjectDetector} instance with a model buffer and the default {@link
- + * ObjectDetectorOptions}.
- + *
- + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
- + * model
- + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- + * {@link MappedByteBuffer} * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ObjectDetector createFromBuffer(final ByteBuffer modelBuffer) {
- + return createFromBufferAndOptions(modelBuffer, ObjectDetectorOptions.builder().build());
- }
- - return new ObjectDetector(
- - TaskJniUtils.createHandleFromLibrary(
- - new EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - return initJniWithByteBuffer(
- - modelBuffer,
- - options,
- - TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- - options.getBaseOptions(), options.getNumThreads()));
- - }
- - },
- - OBJECT_DETECTOR_NATIVE_LIB));
- - }
- -
- - /**
- - * Constructor to initialize the JNI with a pointer from C++.
- - *
- - * @param nativeHandle a pointer referencing memory allocated in C++
- - */
- - private ObjectDetector(long nativeHandle) {
- - super(nativeHandle);
- - }
- -
- - /** Options for setting up an ObjectDetector. */
- - @UsedByReflection("object_detector_jni.cc")
- - public static class ObjectDetectorOptions {
- - // Not using AutoValue for this class because scoreThreshold cannot have default value
- - // (otherwise, the default value would override the one in the model metadata) and `Optional` is
- - // not an option here, because
- - // 1. java.util.Optional require Java 8 while we need to support Java 7.
- - // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the
- - // comments for labelAllowList.
- - private final BaseOptions baseOptions;
- - private final String displayNamesLocale;
- - private final int maxResults;
- - private final float scoreThreshold;
- - private final boolean isScoreThresholdSet;
- - // As an open source project, we've been trying avoiding depending on common java libraries,
- - // such as Guava, because it may introduce conflicts with clients who also happen to use those
- - // libraries. Therefore, instead of using ImmutableList here, we convert the List into
- - // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
- - // vulnerable.
- - private final List<String> labelAllowList;
- - private final List<String> labelDenyList;
- - private final int numThreads;
- -
- - public static Builder builder() {
- - return new Builder();
- +
- + /**
- + * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}.
- + *
- + * @param modelPath path to the detection model with metadata in the assets
- + * @throws IOException if an I/O error occurs when loading the tflite model
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ObjectDetector createFromFileAndOptions(
- + Context context, String modelPath, ObjectDetectorOptions options) throws IOException {
- + return new ObjectDetector(TaskJniUtils.createHandleFromFdAndOptions(
- + context, new FdAndOptionsHandleProvider<ObjectDetectorOptions>() {
- + @Override
- + public long createHandle(int fileDescriptor, long fileDescriptorLength,
- + long fileDescriptorOffset, ObjectDetectorOptions options) {
- + return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength,
- + fileDescriptorOffset, options,
- + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- + options.getBaseOptions(), options.getNumThreads()));
- + }
- + }, OBJECT_DETECTOR_NATIVE_LIB, modelPath, options));
- }
-
- - /** A builder that helps to configure an instance of ObjectDetectorOptions. */
- - public static class Builder {
- - private BaseOptions baseOptions = BaseOptions.builder().build();
- - private String displayNamesLocale = "en";
- - private int maxResults = -1;
- - private float scoreThreshold;
- - private boolean isScoreThresholdSet = false;
- - private List<String> labelAllowList = new ArrayList<>();
- - private List<String> labelDenyList = new ArrayList<>();
- - private int numThreads = -1;
- -
- - private Builder() {}
- -
- - /** Sets the general options to configure Task APIs, such as accelerators. */
- - public Builder setBaseOptions(BaseOptions baseOptions) {
- - this.baseOptions = baseOptions;
- - return this;
- - }
- -
- - /**
- - * Sets the locale to use for display names specified through the TFLite Model Metadata, if
- - * any.
- - *
- - * <p>Defaults to English({@code "en"}). See the <a
- - * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
- - * Metadata schema file.</a> for the accepted pattern of locale.
- - */
- - public Builder setDisplayNamesLocale(String displayNamesLocale) {
- - this.displayNamesLocale = displayNamesLocale;
- - return this;
- - }
- -
- - /**
- - * Sets the maximum number of top-scored detection results to return.
- - *
- - * <p>If < 0, all available results will be returned. If 0, an invalid argument error is
- - * returned. Note that models may intrinsically be limited to returning a maximum number of
- - * results N: if the provided value here is above N, only N results will be returned. Defaults
- - * to -1.
- - *
- - * @throws IllegalArgumentException if maxResults is 0.
- - */
- - public Builder setMaxResults(int maxResults) {
- - if (maxResults == 0) {
- - throw new IllegalArgumentException("maxResults cannot be 0.");
- + /**
- + * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}.
- + *
- + * @param modelFile the detection model {@link File} instance
- + * @throws IOException if an I/O error occurs when loading the tflite model
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ObjectDetector createFromFileAndOptions(
- + File modelFile, final ObjectDetectorOptions options) throws IOException {
- + try (ParcelFileDescriptor descriptor =
- + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- + return new ObjectDetector(
- + TaskJniUtils.createHandleFromLibrary(new TaskJniUtils.EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + return initJniWithModelFdAndOptions(descriptor.getFd(),
- + /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH,
- + /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options,
- + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- + options.getBaseOptions(), options.getNumThreads()));
- + }
- + }, OBJECT_DETECTOR_NATIVE_LIB));
- }
- - this.maxResults = maxResults;
- - return this;
- - }
- -
- - /**
- - * Sets the score threshold that overrides the one provided in the model metadata (if any).
- - * Results below this value are rejected.
- - */
- - public Builder setScoreThreshold(float scoreThreshold) {
- - this.scoreThreshold = scoreThreshold;
- - this.isScoreThresholdSet = true;
- - return this;
- - }
- -
- - /**
- - * Sets the optional allow list of labels.
- - *
- - * <p>If non-empty, detection results whose label is not in this set will be filtered out.
- - * Duplicate or unknown labels are ignored. Mutually exclusive with {@code labelDenyList}. It
- - * will cause {@link IllegalStateException} when calling {@link #createFromFileAndOptions}, if
- - * both {@code labelDenyList} and {@code labelAllowList} are set.
- - */
- - public Builder setLabelAllowList(List<String> labelAllowList) {
- - this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
- - return this;
- - }
- -
- - /**
- - * Sets the optional deny list of labels.
- - *
- - * <p>If non-empty, detection results whose label is in this set will be filtered out.
- - * Duplicate or unknown labels are ignored. Mutually exclusive with {@code labelAllowList}. It
- - * will cause {@link IllegalStateException} when calling {@link #createFromFileAndOptions}, if
- - * both {@code labelDenyList} and {@code labelAllowList} are set.
- - */
- - public Builder setLabelDenyList(List<String> labelDenyList) {
- - this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
- - return this;
- - }
- -
- - /**
- - * Sets the number of threads to be used for TFLite ops that support multi-threading when
- - * running inference with CPU. Defaults to -1.
- - *
- - * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
- - * effect to let TFLite runtime set the value.
- - *
- - * @deprecated use {@link BaseOptions} to configure number of threads instead. This method
- - * will override the number of threads configured from {@link BaseOptions}.
- - */
- - @Deprecated
- - public Builder setNumThreads(int numThreads) {
- - this.numThreads = numThreads;
- - return this;
- - }
- -
- - public ObjectDetectorOptions build() {
- - return new ObjectDetectorOptions(this);
- - }
- }
-
- - @UsedByReflection("object_detector_jni.cc")
- - public String getDisplayNamesLocale() {
- - return displayNamesLocale;
- + /**
- + * Creates an {@link ObjectDetector} instance with a model buffer and {@link
- + * ObjectDetectorOptions}.
- + *
- + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
- + * model
- + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- + * {@link MappedByteBuffer}
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ObjectDetector createFromBufferAndOptions(
- + final ByteBuffer modelBuffer, final ObjectDetectorOptions options) {
- + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- + throw new IllegalArgumentException(
- + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
- + }
- + return new ObjectDetector(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + return initJniWithByteBuffer(modelBuffer, options,
- + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- + options.getBaseOptions(), options.getNumThreads()));
- + }
- + }, OBJECT_DETECTOR_NATIVE_LIB));
- }
-
- - @UsedByReflection("object_detector_jni.cc")
- - public int getMaxResults() {
- - return maxResults;
- + /**
- + * Constructor to initialize the JNI with a pointer from C++.
- + *
- + * @param nativeHandle a pointer referencing memory allocated in C++
- + */
- + private ObjectDetector(long nativeHandle) {
- + super(nativeHandle);
- }
-
- + /** Options for setting up an ObjectDetector. */
- @UsedByReflection("object_detector_jni.cc")
- - public float getScoreThreshold() {
- - return scoreThreshold;
- + public static class ObjectDetectorOptions {
- + // Not using AutoValue for this class because scoreThreshold cannot have default value
- + // (otherwise, the default value would override the one in the model metadata) and
- + // `Optional` is not an option here, because
- + // 1. java.util.Optional require Java 8 while we need to support Java 7.
- + // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See
- + // the comments for labelAllowList.
- + private final BaseOptions baseOptions;
- + private final String displayNamesLocale;
- + private final int maxResults;
- + private final float scoreThreshold;
- + private final boolean isScoreThresholdSet;
- + // As an open source project, we've been trying avoiding depending on common java libraries,
- + // such as Guava, because it may introduce conflicts with clients who also happen to use
- + // those libraries. Therefore, instead of using ImmutableList here, we convert the List into
- + // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
- + // vulnerable.
- + private final List<String> labelAllowList;
- + private final List<String> labelDenyList;
- + private final int numThreads;
- +
- + public static Builder builder() {
- + return new Builder();
- + }
- +
- + /** A builder that helps to configure an instance of ObjectDetectorOptions. */
- + public static class Builder {
- + private BaseOptions baseOptions = BaseOptions.builder().build();
- + private String displayNamesLocale = "en";
- + private int maxResults = -1;
- + private float scoreThreshold;
- + private boolean isScoreThresholdSet = false;
- + private List<String> labelAllowList = new ArrayList<>();
- + private List<String> labelDenyList = new ArrayList<>();
- + private int numThreads = -1;
- +
- + private Builder() {}
- +
- + /** Sets the general options to configure Task APIs, such as accelerators. */
- + public Builder setBaseOptions(BaseOptions baseOptions) {
- + this.baseOptions = baseOptions;
- + return this;
- + }
- +
- + /**
- + * Sets the locale to use for display names specified through the TFLite Model Metadata,
- + * if any.
- + *
- + * <p>Defaults to English({@code "en"}). See the <a
- + * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
- + * Metadata schema file.</a> for the accepted pattern of locale.
- + */
- + public Builder setDisplayNamesLocale(String displayNamesLocale) {
- + this.displayNamesLocale = displayNamesLocale;
- + return this;
- + }
- +
- + /**
- + * Sets the maximum number of top-scored detection results to return.
- + *
- + * <p>If < 0, all available results will be returned. If 0, an invalid argument error is
- + * returned. Note that models may intrinsically be limited to returning a maximum number
- + * of results N: if the provided value here is above N, only N results will be returned.
- + * Defaults to -1.
- + *
- + * @throws IllegalArgumentException if maxResults is 0.
- + */
- + public Builder setMaxResults(int maxResults) {
- + if (maxResults == 0) {
- + throw new IllegalArgumentException("maxResults cannot be 0.");
- + }
- + this.maxResults = maxResults;
- + return this;
- + }
- +
- + /**
- + * Sets the score threshold that overrides the one provided in the model metadata (if
- + * any). Results below this value are rejected.
- + */
- + public Builder setScoreThreshold(float scoreThreshold) {
- + this.scoreThreshold = scoreThreshold;
- + this.isScoreThresholdSet = true;
- + return this;
- + }
- +
- + /**
- + * Sets the optional allow list of labels.
- + *
- + * <p>If non-empty, detection results whose label is not in this set will be filtered
- + * out. Duplicate or unknown labels are ignored. Mutually exclusive with {@code
- + * labelDenyList}. It will cause {@link IllegalStateException} when calling {@link
- + * #createFromFileAndOptions}, if both {@code labelDenyList} and {@code labelAllowList}
- + * are set.
- + */
- + public Builder setLabelAllowList(List<String> labelAllowList) {
- + this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
- + return this;
- + }
- +
- + /**
- + * Sets the optional deny list of labels.
- + *
- + * <p>If non-empty, detection results whose label is in this set will be filtered out.
- + * Duplicate or unknown labels are ignored. Mutually exclusive with {@code
- + * labelAllowList}. It will cause {@link IllegalStateException} when calling {@link
- + * #createFromFileAndOptions}, if both {@code labelDenyList} and {@code labelAllowList}
- + * are set.
- + */
- + public Builder setLabelDenyList(List<String> labelDenyList) {
- + this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
- + return this;
- + }
- +
- + /**
- + * Sets the number of threads to be used for TFLite ops that support multi-threading
- + * when running inference with CPU. Defaults to -1.
- + *
- + * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has
- + * the effect to let TFLite runtime set the value.
- + *
- + * @deprecated use {@link BaseOptions} to configure number of threads instead. This
- + * method
- + * will override the number of threads configured from {@link BaseOptions}.
- + */
- + @Deprecated
- + public Builder setNumThreads(int numThreads) {
- + this.numThreads = numThreads;
- + return this;
- + }
- +
- + public ObjectDetectorOptions build() {
- + return new ObjectDetectorOptions(this);
- + }
- + }
- +
- + @UsedByReflection("object_detector_jni.cc")
- + public String getDisplayNamesLocale() {
- + return displayNamesLocale;
- + }
- +
- + @UsedByReflection("object_detector_jni.cc")
- + public int getMaxResults() {
- + return maxResults;
- + }
- +
- + @UsedByReflection("object_detector_jni.cc")
- + public float getScoreThreshold() {
- + return scoreThreshold;
- + }
- +
- + @UsedByReflection("object_detector_jni.cc")
- + public boolean getIsScoreThresholdSet() {
- + return isScoreThresholdSet;
- + }
- +
- + @UsedByReflection("object_detector_jni.cc")
- + public List<String> getLabelAllowList() {
- + return new ArrayList<>(labelAllowList);
- + }
- +
- + @UsedByReflection("object_detector_jni.cc")
- + public List<String> getLabelDenyList() {
- + return new ArrayList<>(labelDenyList);
- + }
- +
- + @UsedByReflection("object_detector_jni.cc")
- + public int getNumThreads() {
- + return numThreads;
- + }
- +
- + public BaseOptions getBaseOptions() {
- + return baseOptions;
- + }
- +
- + private ObjectDetectorOptions(Builder builder) {
- + displayNamesLocale = builder.displayNamesLocale;
- + maxResults = builder.maxResults;
- + scoreThreshold = builder.scoreThreshold;
- + isScoreThresholdSet = builder.isScoreThresholdSet;
- + labelAllowList = builder.labelAllowList;
- + labelDenyList = builder.labelDenyList;
- + numThreads = builder.numThreads;
- + baseOptions = builder.baseOptions;
- + }
- }
-
- - @UsedByReflection("object_detector_jni.cc")
- - public boolean getIsScoreThresholdSet() {
- - return isScoreThresholdSet;
- + /**
- + * Performs actual detection on the provided image.
- + *
- + * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types:
- + *
- + * <ul>
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- + * </ul>
- + *
- + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + * @throws IllegalArgumentException if the color space type of image is unsupported
- + */
- + public List<Detection> detect(TensorImage image) {
- + return detect(image, ImageProcessingOptions.builder().build());
- }
-
- - @UsedByReflection("object_detector_jni.cc")
- - public List<String> getLabelAllowList() {
- - return new ArrayList<>(labelAllowList);
- + /**
- + * Performs actual detection on the provided image.
- + *
- + * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types:
- + *
- + * <ul>
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- + * </ul>
- + *
- + * <p>{@link ObjectDetector} supports the following options:
- + *
- + * <ul>
- + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
- + * </ul>
- + *
- + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- + * @param options the options to configure how to preprocess the image
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + * @throws IllegalArgumentException if the color space type of image is unsupported
- + */
- + public List<Detection> detect(TensorImage image, ImageProcessingOptions options) {
- + return run(new InferenceProvider<List<Detection>>() {
- + @Override
- + public List<Detection> run(
- + long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
- + return detect(frameBufferHandle, options);
- + }
- + }, image, options);
- }
-
- - @UsedByReflection("object_detector_jni.cc")
- - public List<String> getLabelDenyList() {
- - return new ArrayList<>(labelDenyList);
- + /**
- + * Performs actual detection on the provided {@code MlImage}.
- + *
- + * @param image an {@code MlImage} object that represents an image
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- + */
- + public List<Detection> detect(MlImage image) {
- + return detect(image, ImageProcessingOptions.builder().build());
- }
-
- - @UsedByReflection("object_detector_jni.cc")
- - public int getNumThreads() {
- - return numThreads;
- + /**
- + * Performs actual detection on the provided {@code MlImage} with {@link
- + * ImageProcessingOptions}.
- + *
- + * <p>{@link ObjectDetector} supports the following options:
- + *
- + * <ul>
- + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
- + * MlImage#getRotation()} is not effective.
- + * </ul>
- + *
- + * @param image an {@code MlImage} object that represents an image
- + * @param options the options to configure how to preprocess the image
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- + */
- + public List<Detection> detect(MlImage image, ImageProcessingOptions options) {
- + image.getInternal().acquire();
- + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- + List<Detection> result = detect(tensorImage, options);
- + image.close();
- + return result;
- }
-
- - public BaseOptions getBaseOptions() {
- - return baseOptions;
- + private List<Detection> detect(long frameBufferHandle, ImageProcessingOptions options) {
- + checkNotClosed();
- +
- + return detectNative(getNativeHandle(), frameBufferHandle);
- }
-
- - private ObjectDetectorOptions(Builder builder) {
- - displayNamesLocale = builder.displayNamesLocale;
- - maxResults = builder.maxResults;
- - scoreThreshold = builder.scoreThreshold;
- - isScoreThresholdSet = builder.isScoreThresholdSet;
- - labelAllowList = builder.labelAllowList;
- - labelDenyList = builder.labelDenyList;
- - numThreads = builder.numThreads;
- - baseOptions = builder.baseOptions;
- + private static native long initJniWithModelFdAndOptions(int fileDescriptor,
- + long fileDescriptorLength, long fileDescriptorOffset, ObjectDetectorOptions options,
- + long baseOptionsHandle);
- +
- + private static native long initJniWithByteBuffer(
- + ByteBuffer modelBuffer, ObjectDetectorOptions options, long baseOptionsHandle);
- +
- + private static native List<Detection> detectNative(long nativeHandle, long frameBufferHandle);
- +
- + @Override
- + protected void deinit(long nativeHandle) {
- + deinitJni(nativeHandle);
- }
- - }
- -
- - /**
- - * Performs actual detection on the provided image.
- - *
- - * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types:
- - *
- - * <ul>
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- - * </ul>
- - *
- - * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - * @throws IllegalArgumentException if the color space type of image is unsupported
- - */
- - public List<Detection> detect(TensorImage image) {
- - return detect(image, ImageProcessingOptions.builder().build());
- - }
- -
- - /**
- - * Performs actual detection on the provided image.
- - *
- - * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types:
- - *
- - * <ul>
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- - * </ul>
- - *
- - * <p>{@link ObjectDetector} supports the following options:
- - *
- - * <ul>
- - * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- - * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
- - * </ul>
- - *
- - * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- - * @param options the options to configure how to preprocess the image
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - * @throws IllegalArgumentException if the color space type of image is unsupported
- - */
- - public List<Detection> detect(TensorImage image, ImageProcessingOptions options) {
- - return run(
- - new InferenceProvider<List<Detection>>() {
- - @Override
- - public List<Detection> run(
- - long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
- - return detect(frameBufferHandle, options);
- - }
- - },
- - image,
- - options);
- - }
- -
- - /**
- - * Performs actual detection on the provided {@code MlImage}.
- - *
- - * @param image an {@code MlImage} object that represents an image
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- - */
- - public List<Detection> detect(MlImage image) {
- - return detect(image, ImageProcessingOptions.builder().build());
- - }
- -
- - /**
- - * Performs actual detection on the provided {@code MlImage} with {@link ImageProcessingOptions}.
- - *
- - * <p>{@link ObjectDetector} supports the following options:
- - *
- - * <ul>
- - * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- - * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
- - * MlImage#getRotation()} is not effective.
- - * </ul>
- - *
- - * @param image an {@code MlImage} object that represents an image
- - * @param options the options to configure how to preprocess the image
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- - */
- - public List<Detection> detect(MlImage image, ImageProcessingOptions options) {
- - image.getInternal().acquire();
- - TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- - List<Detection> result = detect(tensorImage, options);
- - image.close();
- - return result;
- - }
- -
- - private List<Detection> detect(long frameBufferHandle, ImageProcessingOptions options) {
- - checkNotClosed();
- -
- - return detectNative(getNativeHandle(), frameBufferHandle);
- - }
- -
- - private static native long initJniWithModelFdAndOptions(
- - int fileDescriptor,
- - long fileDescriptorLength,
- - long fileDescriptorOffset,
- - ObjectDetectorOptions options,
- - long baseOptionsHandle);
- -
- - private static native long initJniWithByteBuffer(
- - ByteBuffer modelBuffer, ObjectDetectorOptions options, long baseOptionsHandle);
- -
- - private static native List<Detection> detectNative(long nativeHandle, long frameBufferHandle);
- -
- - @Override
- - protected void deinit(long nativeHandle) {
- - deinitJni(nativeHandle);
- - }
- -
- - /**
- - * Native implementation to release memory pointed by the pointer.
- - *
- - * @param nativeHandle pointer to memory allocated
- - */
- - private native void deinitJni(long nativeHandle);
- +
- + /**
- + * Native implementation to release memory pointed by the pointer.
- + *
- + * @param nativeHandle pointer to memory allocated
- + */
- + private native void deinitJni(long nativeHandle);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/ImageSearcher.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/ImageSearcher.java
- index 7a02ad8a037a2..d3d1e6a4f4878 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/ImageSearcher.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/ImageSearcher.java
- @@ -19,13 +19,10 @@ import android.content.Context;
- import android.content.res.AssetFileDescriptor;
- import android.graphics.Rect;
- import android.os.ParcelFileDescriptor;
- +
- import com.google.android.odml.image.MlImage;
- import com.google.auto.value.AutoValue;
- -import java.io.File;
- -import java.io.IOException;
- -import java.nio.ByteBuffer;
- -import java.nio.MappedByteBuffer;
- -import java.util.List;
- +
- import org.tensorflow.lite.support.image.MlImageAdapter;
- import org.tensorflow.lite.support.image.TensorImage;
- import org.tensorflow.lite.task.core.BaseOptions;
- @@ -37,6 +34,12 @@ import org.tensorflow.lite.task.processor.SearcherOptions;
- import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
- import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
-
- +import java.io.File;
- +import java.io.IOException;
- +import java.nio.ByteBuffer;
- +import java.nio.MappedByteBuffer;
- +import java.util.List;
- +
- /**
- * Performs similarity search on images.
- *
- @@ -66,330 +69,292 @@ import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
- * the single file format (index file packed in the model) is supported.
- */
- public final class ImageSearcher extends BaseVisionTaskApi {
- + private static final String IMAGE_SEARCHER_NATIVE_LIB = "task_vision_jni";
- + private static final int OPTIONAL_FD_LENGTH = -1;
- + private static final int OPTIONAL_FD_OFFSET = -1;
-
- - private static final String IMAGE_SEARCHER_NATIVE_LIB = "task_vision_jni";
- - private static final int OPTIONAL_FD_LENGTH = -1;
- - private static final int OPTIONAL_FD_OFFSET = -1;
- -
- - /**
- - * Creates an {@link ImageSearcher} instance from {@link ImageSearcherOptions}.
- - *
- - * @param modelPath path of the search model with metadata in the assets
- - * @throws IOException if an I/O error occurs when loading the tflite model or the index file
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ImageSearcher createFromFileAndOptions(
- - Context context, String modelPath, final ImageSearcherOptions options) throws IOException {
- - try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
- - return createFromModelFdAndOptions(
- - /*modelDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
- - /*modelDescriptorLength=*/ assetFileDescriptor.getLength(),
- - /*modelDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
- - options);
- + /**
- + * Creates an {@link ImageSearcher} instance from {@link ImageSearcherOptions}.
- + *
- + * @param modelPath path of the search model with metadata in the assets
- + * @throws IOException if an I/O error occurs when loading the tflite model or the index file
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ImageSearcher createFromFileAndOptions(Context context, String modelPath,
- + final ImageSearcherOptions options) throws IOException {
- + try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
- + return createFromModelFdAndOptions(
- + /*modelDescriptor=*/assetFileDescriptor.getParcelFileDescriptor().getFd(),
- + /*modelDescriptorLength=*/assetFileDescriptor.getLength(),
- + /*modelDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options);
- + }
- }
- - }
- -
- - /**
- - * Creates an {@link ImageSearcher} instance.
- - *
- - * @param modelFile the search model {@link File} instance
- - * @throws IOException if an I/O error occurs when loading the tflite model or the index file
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ImageSearcher createFromFileAndOptions(
- - File modelFile, final ImageSearcherOptions options) throws IOException {
- - try (ParcelFileDescriptor descriptor =
- - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- - return createFromModelFdAndOptions(
- - /*modelDescriptor=*/ descriptor.getFd(),
- - /*modelDescriptorLength=*/ OPTIONAL_FD_LENGTH,
- - /*modelDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
- - options);
- +
- + /**
- + * Creates an {@link ImageSearcher} instance.
- + *
- + * @param modelFile the search model {@link File} instance
- + * @throws IOException if an I/O error occurs when loading the tflite model or the index file
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ImageSearcher createFromFileAndOptions(
- + File modelFile, final ImageSearcherOptions options) throws IOException {
- + try (ParcelFileDescriptor descriptor =
- + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- + return createFromModelFdAndOptions(
- + /*modelDescriptor=*/descriptor.getFd(),
- + /*modelDescriptorLength=*/OPTIONAL_FD_LENGTH,
- + /*modelDescriptorOffset=*/OPTIONAL_FD_OFFSET, options);
- + }
- }
- - }
- -
- - /**
- - * Creates an {@link ImageSearcher} instance with a model buffer and {@link ImageSearcherOptions}.
- - *
- - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the search
- - * model
- - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- - * {@link MappedByteBuffer}
- - * @throws IOException if an I/O error occurs when loading the index file
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ImageSearcher createFromBufferAndOptions(
- - final ByteBuffer modelBuffer, final ImageSearcherOptions options) throws IOException {
- - if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- - throw new IllegalArgumentException(
- - "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
- +
- + /**
- + * Creates an {@link ImageSearcher} instance with a model buffer and {@link
- + * ImageSearcherOptions}.
- + *
- + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the search
- + * model
- + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- + * {@link MappedByteBuffer}
- + * @throws IOException if an I/O error occurs when loading the index file
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ImageSearcher createFromBufferAndOptions(
- + final ByteBuffer modelBuffer, final ImageSearcherOptions options) throws IOException {
- + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- + throw new IllegalArgumentException(
- + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
- + }
- + if (options.getSearcherOptions().getIndexFile() != null) {
- + try (ParcelFileDescriptor indexDescriptor =
- + ParcelFileDescriptor.open(options.getSearcherOptions().getIndexFile(),
- + ParcelFileDescriptor.MODE_READ_ONLY)) {
- + return createFromBufferAndOptionsImpl(
- + modelBuffer, options, indexDescriptor.getFd());
- + }
- + } else {
- + return createFromBufferAndOptionsImpl(modelBuffer, options, /*indexFd=*/0);
- + }
- }
- - if (options.getSearcherOptions().getIndexFile() != null) {
- - try (ParcelFileDescriptor indexDescriptor =
- - ParcelFileDescriptor.open(
- - options.getSearcherOptions().getIndexFile(), ParcelFileDescriptor.MODE_READ_ONLY)) {
- - return createFromBufferAndOptionsImpl(modelBuffer, options, indexDescriptor.getFd());
- - }
- - } else {
- - return createFromBufferAndOptionsImpl(modelBuffer, options, /*indexFd=*/ 0);
- +
- + public static ImageSearcher createFromBufferAndOptionsImpl(
- + final ByteBuffer modelBuffer, final ImageSearcherOptions options, final int indexFd) {
- + return new ImageSearcher(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + return initJniWithByteBuffer(modelBuffer,
- + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()),
- + options.getSearcherOptions().getL2Normalize(),
- + options.getSearcherOptions().getQuantize(), indexFd,
- + options.getSearcherOptions().getMaxResults());
- + }
- + }, IMAGE_SEARCHER_NATIVE_LIB));
- }
- - }
- -
- - public static ImageSearcher createFromBufferAndOptionsImpl(
- - final ByteBuffer modelBuffer, final ImageSearcherOptions options, final int indexFd) {
- - return new ImageSearcher(
- - TaskJniUtils.createHandleFromLibrary(
- - new EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - return initJniWithByteBuffer(
- - modelBuffer,
- - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()),
- - options.getSearcherOptions().getL2Normalize(),
- - options.getSearcherOptions().getQuantize(),
- - indexFd,
- - options.getSearcherOptions().getMaxResults());
- - }
- - },
- - IMAGE_SEARCHER_NATIVE_LIB));
- - }
- -
- - /**
- - * Constructor to initialize the JNI with a pointer from C++.
- - *
- - * @param nativeHandle a pointer referencing memory allocated in C++
- - */
- - ImageSearcher(long nativeHandle) {
- - super(nativeHandle);
- - }
- -
- - /** Options for setting up an ImageSearcher. */
- - @AutoValue
- - public abstract static class ImageSearcherOptions {
- -
- - abstract BaseOptions getBaseOptions();
- -
- - abstract SearcherOptions getSearcherOptions();
- -
- - public static Builder builder() {
- - return new AutoValue_ImageSearcher_ImageSearcherOptions.Builder()
- - .setBaseOptions(BaseOptions.builder().build())
- - .setSearcherOptions(SearcherOptions.builder().build());
- +
- + /**
- + * Constructor to initialize the JNI with a pointer from C++.
- + *
- + * @param nativeHandle a pointer referencing memory allocated in C++
- + */
- + ImageSearcher(long nativeHandle) {
- + super(nativeHandle);
- }
-
- - /** Builder for {@link ImageSearcherOptions}. */
- - @AutoValue.Builder
- - public abstract static class Builder {
- - /** Sets the general options to configure Task APIs, such as accelerators. */
- - public abstract Builder setBaseOptions(BaseOptions baseOptions);
- + /** Options for setting up an ImageSearcher. */
- + @AutoValue
- + public abstract static class ImageSearcherOptions {
- + abstract BaseOptions getBaseOptions();
- +
- + abstract SearcherOptions getSearcherOptions();
- +
- + public static Builder builder() {
- + return new AutoValue_ImageSearcher_ImageSearcherOptions.Builder()
- + .setBaseOptions(BaseOptions.builder().build())
- + .setSearcherOptions(SearcherOptions.builder().build());
- + }
- +
- + /** Builder for {@link ImageSearcherOptions}. */
- + @AutoValue.Builder
- + public abstract static class Builder {
- + /** Sets the general options to configure Task APIs, such as accelerators. */
- + public abstract Builder setBaseOptions(BaseOptions baseOptions);
-
- - /** Sets the options to configure Searcher API. */
- - public abstract Builder setSearcherOptions(SearcherOptions searcherOptions);
- + /** Sets the options to configure Searcher API. */
- + public abstract Builder setSearcherOptions(SearcherOptions searcherOptions);
-
- - public abstract ImageSearcherOptions build();
- + public abstract ImageSearcherOptions build();
- + }
- }
- - }
- -
- - /**
- - * Performs embedding extraction on the provided {@link TensorImage}, followed by nearest-neighbor
- - * search in the index.
- - *
- - * <p>{@link ImageSearcher} supports the following {@link TensorImage} color space types:
- - *
- - * <ul>
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- - * </ul>
- - *
- - * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- - * @throws IllegalArgumentException if the color space type of image is unsupported
- - */
- - public List<NearestNeighbor> search(TensorImage image) {
- - return search(image, ImageProcessingOptions.builder().build());
- - }
- -
- - /**
- - * Performs embedding extraction on the provided {@link TensorImage} with {@link
- - * ImageProcessingOptions}, followed by nearest-neighbor search in the index.
- - *
- - * <p>{@link ImageSearcher} supports the following options:
- - *
- - * <ul>
- - * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
- - * defaults to the entire image.
- - * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- - * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
- - * </ul>
- - *
- - * <p>{@link ImageSearcher} supports the following {@link TensorImage} color space types:
- - *
- - * <ul>
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- - * </ul>
- - *
- - * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- - * @throws IllegalArgumentException if the color space type of image is unsupported
- - */
- - public List<NearestNeighbor> search(TensorImage image, ImageProcessingOptions options) {
- - return run(
- - new InferenceProvider<List<NearestNeighbor>>() {
- - @Override
- - public List<NearestNeighbor> run(
- - long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
- - return search(frameBufferHandle, width, height, options);
- - }
- - },
- - image,
- - options);
- - }
- -
- - /**
- - * Performs embedding extraction on the provided {@code MlImage}, followed by nearest-neighbor
- - * search in the index.
- - *
- - * @param image an {@code MlImage} object that represents an image
- - * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- - */
- - public List<NearestNeighbor> search(MlImage image) {
- - return search(image, ImageProcessingOptions.builder().build());
- - }
- -
- - /**
- - * Performs embedding extraction on the provided {@code MlImage} with {@link
- - * ImageProcessingOptions}, followed by nearest-neighbor search in the index.
- - *
- - * <p>{@link ImageSearcher} supports the following options:
- - *
- - * <ul>
- - * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
- - * defaults to the entire image.
- - * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- - * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
- - * MlImage#getRotation()} is not effective.
- - * </ul>
- - *
- - * @param image a {@code MlImage} object that represents an image
- - * @param options configures options including ROI and rotation
- - * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- - */
- - public List<NearestNeighbor> search(MlImage image, ImageProcessingOptions options) {
- - image.getInternal().acquire();
- - TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- - List<NearestNeighbor> result = search(tensorImage, options);
- - image.close();
- - return result;
- - }
- -
- - private List<NearestNeighbor> search(
- - long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
- - checkNotClosed();
- - Rect roi = options.getRoi().isEmpty() ? new Rect(0, 0, width, height) : options.getRoi();
- - return searchNative(
- - getNativeHandle(),
- - frameBufferHandle,
- - new int[] {roi.left, roi.top, roi.width(), roi.height()});
- - }
- -
- - private static ImageSearcher createFromModelFdAndOptions(
- - final int modelDescriptor,
- - final long modelDescriptorLength,
- - final long modelDescriptorOffset,
- - final ImageSearcherOptions options)
- - throws IOException {
- - if (options.getSearcherOptions().getIndexFile() != null) {
- - // indexDescriptor must be alive before ImageSearcher is initialized completely in the native
- - // layer.
- - try (ParcelFileDescriptor indexDescriptor =
- - ParcelFileDescriptor.open(
- - options.getSearcherOptions().getIndexFile(), ParcelFileDescriptor.MODE_READ_ONLY)) {
- - return createFromModelFdAndOptionsImpl(
- - modelDescriptor,
- - modelDescriptorLength,
- - modelDescriptorOffset,
- - options,
- - indexDescriptor.getFd());
- - }
- - } else {
- - // Index file is not configured. We'll check if the model contains one in the native layer.
- - return createFromModelFdAndOptionsImpl(
- - modelDescriptor, modelDescriptorLength, modelDescriptorOffset, options, /*indexFd=*/ 0);
- +
- + /**
- + * Performs embedding extraction on the provided {@link TensorImage}, followed by
- + * nearest-neighbor search in the index.
- + *
- + * <p>{@link ImageSearcher} supports the following {@link TensorImage} color space types:
- + *
- + * <ul>
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- + * </ul>
- + *
- + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- + * @throws IllegalArgumentException if the color space type of image is unsupported
- + */
- + public List<NearestNeighbor> search(TensorImage image) {
- + return search(image, ImageProcessingOptions.builder().build());
- + }
- +
- + /**
- + * Performs embedding extraction on the provided {@link TensorImage} with {@link
- + * ImageProcessingOptions}, followed by nearest-neighbor search in the index.
- + *
- + * <p>{@link ImageSearcher} supports the following options:
- + *
- + * <ul>
- + * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
- + * defaults to the entire image.
- + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
- + * </ul>
- + *
- + * <p>{@link ImageSearcher} supports the following {@link TensorImage} color space types:
- + *
- + * <ul>
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- + * </ul>
- + *
- + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- + * @throws IllegalArgumentException if the color space type of image is unsupported
- + */
- + public List<NearestNeighbor> search(TensorImage image, ImageProcessingOptions options) {
- + return run(new InferenceProvider<List<NearestNeighbor>>() {
- + @Override
- + public List<NearestNeighbor> run(
- + long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
- + return search(frameBufferHandle, width, height, options);
- + }
- + }, image, options);
- + }
- +
- + /**
- + * Performs embedding extraction on the provided {@code MlImage}, followed by nearest-neighbor
- + * search in the index.
- + *
- + * @param image an {@code MlImage} object that represents an image
- + * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- + */
- + public List<NearestNeighbor> search(MlImage image) {
- + return search(image, ImageProcessingOptions.builder().build());
- }
- - }
- -
- - private static ImageSearcher createFromModelFdAndOptionsImpl(
- - final int modelDescriptor,
- - final long modelDescriptorLength,
- - final long modelDescriptorOffset,
- - final ImageSearcherOptions options,
- - final int indexFd) {
- - long nativeHandle =
- - TaskJniUtils.createHandleFromLibrary(
- - new EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - return initJniWithModelFdAndOptions(
- - modelDescriptor,
- - modelDescriptorLength,
- - modelDescriptorOffset,
- - TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()),
- - options.getSearcherOptions().getL2Normalize(),
- - options.getSearcherOptions().getQuantize(),
- - indexFd,
- - options.getSearcherOptions().getMaxResults());
- - }
- - },
- - IMAGE_SEARCHER_NATIVE_LIB);
- - return new ImageSearcher(nativeHandle);
- - }
- -
- - private static native long initJniWithModelFdAndOptions(
- - int modelDescriptor,
- - long modelDescriptorLength,
- - long modelDescriptorOffset,
- - long baseOptionsHandle,
- - boolean l2Normalize,
- - boolean quantize,
- - int indexDescriptor,
- - int maxResults);
- -
- - private static native long initJniWithByteBuffer(
- - ByteBuffer modelBuffer,
- - long baseOptionsHandle,
- - boolean l2Normalize,
- - boolean quantize,
- - int indexFileDescriptor,
- - int maxResults);
- -
- - /**
- - * The native method to search an image based on the ROI specified.
- - *
- - * @param roi the ROI of the input image, an array representing the bounding box as {left, top,
- - * width, height}
- - */
- - private static native List<NearestNeighbor> searchNative(
- - long nativeHandle, long frameBufferHandle, int[] roi);
- -
- - @Override
- - protected void deinit(long nativeHandle) {
- - deinitJni(nativeHandle);
- - }
- -
- - /**
- - * Native implementation to release memory pointed by the pointer.
- - *
- - * @param nativeHandle pointer to memory allocated
- - */
- - private native void deinitJni(long nativeHandle);
- +
- + /**
- + * Performs embedding extraction on the provided {@code MlImage} with {@link
- + * ImageProcessingOptions}, followed by nearest-neighbor search in the index.
- + *
- + * <p>{@link ImageSearcher} supports the following options:
- + *
- + * <ul>
- + * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
- + * defaults to the entire image.
- + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
- + * MlImage#getRotation()} is not effective.
- + * </ul>
- + *
- + * @param image a {@code MlImage} object that represents an image
- + * @param options configures options including ROI and rotation
- + * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- + */
- + public List<NearestNeighbor> search(MlImage image, ImageProcessingOptions options) {
- + image.getInternal().acquire();
- + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- + List<NearestNeighbor> result = search(tensorImage, options);
- + image.close();
- + return result;
- + }
- +
- + private List<NearestNeighbor> search(
- + long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
- + checkNotClosed();
- + Rect roi = options.getRoi().isEmpty() ? new Rect(0, 0, width, height) : options.getRoi();
- + return searchNative(getNativeHandle(), frameBufferHandle,
- + new int[] {roi.left, roi.top, roi.width(), roi.height()});
- + }
- +
- + private static ImageSearcher createFromModelFdAndOptions(final int modelDescriptor,
- + final long modelDescriptorLength, final long modelDescriptorOffset,
- + final ImageSearcherOptions options) throws IOException {
- + if (options.getSearcherOptions().getIndexFile() != null) {
- + // indexDescriptor must be alive before ImageSearcher is initialized completely in the
- + // native layer.
- + try (ParcelFileDescriptor indexDescriptor =
- + ParcelFileDescriptor.open(options.getSearcherOptions().getIndexFile(),
- + ParcelFileDescriptor.MODE_READ_ONLY)) {
- + return createFromModelFdAndOptionsImpl(modelDescriptor, modelDescriptorLength,
- + modelDescriptorOffset, options, indexDescriptor.getFd());
- + }
- + } else {
- + // Index file is not configured. We'll check if the model contains one in the native
- + // layer.
- + return createFromModelFdAndOptionsImpl(modelDescriptor, modelDescriptorLength,
- + modelDescriptorOffset, options, /*indexFd=*/0);
- + }
- + }
- +
- + private static ImageSearcher createFromModelFdAndOptionsImpl(final int modelDescriptor,
- + final long modelDescriptorLength, final long modelDescriptorOffset,
- + final ImageSearcherOptions options, final int indexFd) {
- + long nativeHandle = TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + return initJniWithModelFdAndOptions(modelDescriptor, modelDescriptorLength,
- + modelDescriptorOffset,
- + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()),
- + options.getSearcherOptions().getL2Normalize(),
- + options.getSearcherOptions().getQuantize(), indexFd,
- + options.getSearcherOptions().getMaxResults());
- + }
- + }, IMAGE_SEARCHER_NATIVE_LIB);
- + return new ImageSearcher(nativeHandle);
- + }
- +
- + private static native long initJniWithModelFdAndOptions(int modelDescriptor,
- + long modelDescriptorLength, long modelDescriptorOffset, long baseOptionsHandle,
- + boolean l2Normalize, boolean quantize, int indexDescriptor, int maxResults);
- +
- + private static native long initJniWithByteBuffer(ByteBuffer modelBuffer, long baseOptionsHandle,
- + boolean l2Normalize, boolean quantize, int indexFileDescriptor, int maxResults);
- +
- + /**
- + * The native method to search an image based on the ROI specified.
- + *
- + * @param roi the ROI of the input image, an array representing the bounding box as {left, top,
- + * width, height}
- + */
- + private static native List<NearestNeighbor> searchNative(
- + long nativeHandle, long frameBufferHandle, int[] roi);
- +
- + @Override
- + protected void deinit(long nativeHandle) {
- + deinitJni(nativeHandle);
- + }
- +
- + /**
- + * Native implementation to release memory pointed by the pointer.
- + *
- + * @param nativeHandle pointer to memory allocated
- + */
- + private native void deinitJni(long nativeHandle);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java
- index a92e70ebc09b4..7a7a5b323f43b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java
- @@ -17,72 +17,74 @@ package org.tensorflow.lite.task.vision.segmenter;
-
- import android.graphics.Color;
- import android.os.Build;
- +
- import androidx.annotation.RequiresApi;
- +
- import com.google.auto.value.AutoValue;
- +
- import org.tensorflow.lite.task.core.annotations.UsedByReflection;
-
- /** Represents a label associated with a color for display purposes. */
- @AutoValue
- @UsedByReflection("image_segmentation_jni.cc")
- public abstract class ColoredLabel {
- + /**
- + * Creates a {@link ColoredLabel} object with an ARGB color int.
- + *
- + * @param label the label string, as provided in the label map packed in the TFLite Model
- + * Metadata.
- + * @param displayName the display name of label, as configured through {@link
- + * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale}
- + * @param argb the color components for the label in ARGB. See <a
- + * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android
- + * Color ints.</a> for more details.
- + */
- + @UsedByReflection("image_segmentation_jni.cc")
- + public static ColoredLabel create(String label, String displayName, int argb) {
- + return new AutoValue_ColoredLabel(label, displayName, argb);
- + }
-
- - /**
- - * Creates a {@link ColoredLabel} object with an ARGB color int.
- - *
- - * @param label the label string, as provided in the label map packed in the TFLite Model
- - * Metadata.
- - * @param displayName the display name of label, as configured through {@link
- - * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale}
- - * @param argb the color components for the label in ARGB. See <a
- - * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android
- - * Color ints.</a> for more details.
- - */
- - @UsedByReflection("image_segmentation_jni.cc")
- - public static ColoredLabel create(String label, String displayName, int argb) {
- - return new AutoValue_ColoredLabel(label, displayName, argb);
- - }
- -
- - /**
- - * Creates a {@link ColoredLabel} object with a {@link android.graphics.Color} instance.
- - *
- - * @param label the label string, as provided in the label map packed in the TFLite Model
- - * Metadata.
- - * @param displayName the display name of label, as configured through {@link
- - * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale}
- - * @param color the color components for the label. The Color instatnce is supported on Android
- - * API level 26 and above. For API level lower than 26, use {@link #create(String, String,
- - * int)}. See <a
- - * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android
- - * Color instances.</a> for more details.
- - */
- - @RequiresApi(Build.VERSION_CODES.O)
- - public static ColoredLabel create(String label, String displayName, Color color) {
- - return new AutoValue_ColoredLabel(label, displayName, color.toArgb());
- - }
- + /**
- + * Creates a {@link ColoredLabel} object with a {@link android.graphics.Color} instance.
- + *
- + * @param label the label string, as provided in the label map packed in the TFLite Model
- + * Metadata.
- + * @param displayName the display name of label, as configured through {@link
- + * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale}
- + * @param color the color components for the label. The Color instatnce is supported on Android
- + * API level 26 and above. For API level lower than 26, use {@link #create(String, String,
- + * int)}. See <a
- + * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android
- + * Color instances.</a> for more details.
- + */
- + @RequiresApi(Build.VERSION_CODES.O)
- + public static ColoredLabel create(String label, String displayName, Color color) {
- + return new AutoValue_ColoredLabel(label, displayName, color.toArgb());
- + }
-
- - public abstract String getlabel();
- + public abstract String getlabel();
-
- - public abstract String getDisplayName();
- + public abstract String getDisplayName();
-
- - /**
- - * Gets the ARGB int that represents the color.
- - *
- - * <p>See <a
- - * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android Color
- - * ints.</a> for more details.
- - */
- - public abstract int getArgb();
- + /**
- + * Gets the ARGB int that represents the color.
- + *
- + * <p>See <a
- + * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android
- + * Color ints.</a> for more details.
- + */
- + public abstract int getArgb();
-
- - /**
- - * Gets the {@link android.graphics.Color} instance of the underlying color.
- - *
- - * <p>The Color instatnce is supported on Android API level 26 and above. For API level lower than
- - * 26, use {@link #getArgb()}. See <a
- - * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android
- - * Color instances.</a> for more details.
- - */
- - @RequiresApi(Build.VERSION_CODES.O)
- - public Color getColor() {
- - return Color.valueOf(getArgb());
- - }
- + /**
- + * Gets the {@link android.graphics.Color} instance of the underlying color.
- + *
- + * <p>The Color instatnce is supported on Android API level 26 and above. For API level lower
- + * than 26, use {@link #getArgb()}. See <a
- + * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android
- + * Color instances.</a> for more details.
- + */
- + @RequiresApi(Build.VERSION_CODES.O)
- + public Color getColor() {
- + return Color.valueOf(getArgb());
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java
- index 0caa7a33e1729..4c3b36304a0e3 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java
- @@ -18,16 +18,10 @@ package org.tensorflow.lite.task.vision.segmenter;
- import android.content.Context;
- import android.content.res.AssetFileDescriptor;
- import android.os.ParcelFileDescriptor;
- +
- import com.google.android.odml.image.MlImage;
- import com.google.auto.value.AutoValue;
- -import java.io.File;
- -import java.io.IOException;
- -import java.nio.ByteBuffer;
- -import java.nio.ByteOrder;
- -import java.nio.MappedByteBuffer;
- -import java.util.ArrayList;
- -import java.util.Arrays;
- -import java.util.List;
- +
- import org.tensorflow.lite.support.image.MlImageAdapter;
- import org.tensorflow.lite.support.image.TensorImage;
- import org.tensorflow.lite.task.core.BaseOptions;
- @@ -37,6 +31,15 @@ import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
- import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
- import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
-
- +import java.io.File;
- +import java.io.IOException;
- +import java.nio.ByteBuffer;
- +import java.nio.ByteOrder;
- +import java.nio.MappedByteBuffer;
- +import java.util.ArrayList;
- +import java.util.Arrays;
- +import java.util.List;
- +
- /**
- * Performs segmentation on images.
- *
- @@ -75,394 +78,365 @@ import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
- * href="https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1">TensorFlow Hub.</a>.
- */
- public final class ImageSegmenter extends BaseVisionTaskApi {
- + private static final String IMAGE_SEGMENTER_NATIVE_LIB = "task_vision_jni";
- + private static final int OPTIONAL_FD_LENGTH = -1;
- + private static final int OPTIONAL_FD_OFFSET = -1;
- +
- + private final OutputType outputType;
- +
- + /**
- + * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
- + *
- + * @param modelPath path of the segmentation model with metadata in the assets
- + * @throws IOException if an I/O error occurs when loading the tflite model
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ImageSegmenter createFromFile(Context context, String modelPath)
- + throws IOException {
- + return createFromFileAndOptions(
- + context, modelPath, ImageSegmenterOptions.builder().build());
- + }
-
- - private static final String IMAGE_SEGMENTER_NATIVE_LIB = "task_vision_jni";
- - private static final int OPTIONAL_FD_LENGTH = -1;
- - private static final int OPTIONAL_FD_OFFSET = -1;
- -
- - private final OutputType outputType;
- -
- - /**
- - * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
- - *
- - * @param modelPath path of the segmentation model with metadata in the assets
- - * @throws IOException if an I/O error occurs when loading the tflite model
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ImageSegmenter createFromFile(Context context, String modelPath)
- - throws IOException {
- - return createFromFileAndOptions(context, modelPath, ImageSegmenterOptions.builder().build());
- - }
- -
- - /**
- - * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
- - *
- - * @param modelFile the segmentation model {@link File} instance
- - * @throws IOException if an I/O error occurs when loading the tflite model
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ImageSegmenter createFromFile(File modelFile) throws IOException {
- - return createFromFileAndOptions(modelFile, ImageSegmenterOptions.builder().build());
- - }
- -
- - /**
- - * Creates an {@link ImageSegmenter} instance with a model buffer and the default {@link
- - * ImageSegmenterOptions}.
- - *
- - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- - * segmentation model
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- - * {@link MappedByteBuffer}
- - */
- - public static ImageSegmenter createFromBuffer(final ByteBuffer modelBuffer) {
- - return createFromBufferAndOptions(modelBuffer, ImageSegmenterOptions.builder().build());
- - }
- -
- - /**
- - * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
- - *
- - * @param modelPath path of the segmentation model with metadata in the assets
- - * @throws IOException if an I/O error occurs when loading the tflite model
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ImageSegmenter createFromFileAndOptions(
- - Context context, String modelPath, final ImageSegmenterOptions options) throws IOException {
- - try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
- - return createFromModelFdAndOptions(
- - /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
- - /*fileDescriptorLength=*/ assetFileDescriptor.getLength(),
- - /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
- - options);
- + /**
- + * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
- + *
- + * @param modelFile the segmentation model {@link File} instance
- + * @throws IOException if an I/O error occurs when loading the tflite model
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ImageSegmenter createFromFile(File modelFile) throws IOException {
- + return createFromFileAndOptions(modelFile, ImageSegmenterOptions.builder().build());
- }
- - }
- -
- - /**
- - * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
- - *
- - * @param modelFile the segmentation model {@link File} instance
- - * @throws IOException if an I/O error occurs when loading the tflite model
- - * @throws IllegalArgumentException if an argument is invalid
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - */
- - public static ImageSegmenter createFromFileAndOptions(
- - File modelFile, final ImageSegmenterOptions options) throws IOException {
- - try (ParcelFileDescriptor descriptor =
- - ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- - return createFromModelFdAndOptions(
- - /*fileDescriptor=*/ descriptor.getFd(),
- - /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
- - /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
- - options);
- +
- + /**
- + * Creates an {@link ImageSegmenter} instance with a model buffer and the default {@link
- + * ImageSegmenterOptions}.
- + *
- + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- + * segmentation model
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- + * {@link MappedByteBuffer}
- + */
- + public static ImageSegmenter createFromBuffer(final ByteBuffer modelBuffer) {
- + return createFromBufferAndOptions(modelBuffer, ImageSegmenterOptions.builder().build());
- + }
- +
- + /**
- + * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
- + *
- + * @param modelPath path of the segmentation model with metadata in the assets
- + * @throws IOException if an I/O error occurs when loading the tflite model
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ImageSegmenter createFromFileAndOptions(Context context, String modelPath,
- + final ImageSegmenterOptions options) throws IOException {
- + try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
- + return createFromModelFdAndOptions(
- + /*fileDescriptor=*/assetFileDescriptor.getParcelFileDescriptor().getFd(),
- + /*fileDescriptorLength=*/assetFileDescriptor.getLength(),
- + /*fileDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options);
- + }
- + }
- +
- + /**
- + * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
- + *
- + * @param modelFile the segmentation model {@link File} instance
- + * @throws IOException if an I/O error occurs when loading the tflite model
- + * @throws IllegalArgumentException if an argument is invalid
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + */
- + public static ImageSegmenter createFromFileAndOptions(
- + File modelFile, final ImageSegmenterOptions options) throws IOException {
- + try (ParcelFileDescriptor descriptor =
- + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- + return createFromModelFdAndOptions(
- + /*fileDescriptor=*/descriptor.getFd(),
- + /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH,
- + /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options);
- + }
- + }
- +
- + /**
- + * Creates an {@link ImageSegmenter} instance with a model buffer and {@link
- + * ImageSegmenterOptions}.
- + *
- + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- + * segmentation model
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- + * {@link MappedByteBuffer}
- + */
- + public static ImageSegmenter createFromBufferAndOptions(
- + final ByteBuffer modelBuffer, final ImageSegmenterOptions options) {
- + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- + throw new IllegalArgumentException(
- + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
- + }
- + return new ImageSegmenter(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + return initJniWithByteBuffer(modelBuffer, options.getDisplayNamesLocale(),
- + options.getOutputType().getValue(),
- + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- + options.getBaseOptions(), options.getNumThreads()));
- + }
- + }, IMAGE_SEGMENTER_NATIVE_LIB), options.getOutputType());
- + }
- +
- + /**
- + * Constructor to initialize the JNI with a pointer from C++.
- + *
- + * @param nativeHandle a pointer referencing memory allocated in C++
- + */
- + private ImageSegmenter(long nativeHandle, OutputType outputType) {
- + super(nativeHandle);
- + this.outputType = outputType;
- + }
- +
- + /** Options for setting up an {@link ImageSegmenter}. */
- + @AutoValue
- + public abstract static class ImageSegmenterOptions {
- + private static final String DEFAULT_DISPLAY_NAME_LOCALE = "en";
- + private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.CATEGORY_MASK;
- + private static final int NUM_THREADS = -1;
- +
- + public abstract BaseOptions getBaseOptions();
- +
- + public abstract String getDisplayNamesLocale();
- +
- + public abstract OutputType getOutputType();
- +
- + public abstract int getNumThreads();
- +
- + public static Builder builder() {
- + return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
- + .setDisplayNamesLocale(DEFAULT_DISPLAY_NAME_LOCALE)
- + .setOutputType(DEFAULT_OUTPUT_TYPE)
- + .setNumThreads(NUM_THREADS)
- + .setBaseOptions(BaseOptions.builder().build());
- + }
- +
- + /** Builder for {@link ImageSegmenterOptions}. */
- + @AutoValue.Builder
- + public abstract static class Builder {
- + /** Sets the general options to configure Task APIs, such as accelerators. */
- + public abstract Builder setBaseOptions(BaseOptions baseOptions);
- +
- + /**
- + * Sets the locale to use for display names specified through the TFLite Model Metadata,
- + * if any.
- + *
- + * <p>Defaults to English({@code "en"}). See the <a
- + * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
- + * Metadata schema file.</a> for the accepted pattern of locale.
- + */
- + public abstract Builder setDisplayNamesLocale(String displayNamesLocale);
- +
- + public abstract Builder setOutputType(OutputType outputType);
- +
- + /**
- + * Sets the number of threads to be used for TFLite ops that support multi-threading
- + * when running inference with CPU. Defaults to -1.
- + *
- + * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has
- + * the effect to let TFLite runtime set the value.
- + *
- + * @deprecated use {@link BaseOptions} to configure number of threads instead. This
- + * method
- + * will override the number of threads configured from {@link BaseOptions}.
- + */
- + @Deprecated
- + public abstract Builder setNumThreads(int numThreads);
- +
- + public abstract ImageSegmenterOptions build();
- + }
- + }
- +
- + /**
- + * Performs actual segmentation on the provided image.
- + *
- + * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types:
- + *
- + * <ul>
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- + * </ul>
- + *
- + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- + * @return results of performing image segmentation. Note that at the time, a single {@link
- + * Segmentation} element is expected to be returned. The result is stored in a {@link List}
- + * for later extension to e.g. instance segmentation models, which may return one
- + * segmentation per object.
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + * @throws IllegalArgumentException if the color space type of image is unsupported
- + */
- + public List<Segmentation> segment(TensorImage image) {
- + return segment(image, ImageProcessingOptions.builder().build());
- + }
- +
- + /**
- + * Performs actual segmentation on the provided image with {@link ImageProcessingOptions}.
- + *
- + * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types:
- + *
- + * <ul>
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- + * </ul>
- + *
- + * <p>{@link ImageSegmenter} supports the following options:
- + *
- + * <ul>
- + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}
- + * </ul>
- + *
- + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- + * @param options the options configure how to preprocess the image
- + * @return results of performing image segmentation. Note that at the time, a single {@link
- + * Segmentation} element is expected to be returned. The result is stored in a {@link List}
- + * for later extension to e.g. instance segmentation models, which may return one
- + * segmentation per object.
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + * @throws IllegalArgumentException if the color space type of image is unsupported
- + */
- + public List<Segmentation> segment(TensorImage image, ImageProcessingOptions options) {
- + return run(new InferenceProvider<List<Segmentation>>() {
- + @Override
- + public List<Segmentation> run(
- + long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
- + return segment(frameBufferHandle, options);
- + }
- + }, image, options);
- }
- - }
- -
- - /**
- - * Creates an {@link ImageSegmenter} instance with a model buffer and {@link
- - * ImageSegmenterOptions}.
- - *
- - * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- - * segmentation model
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- - * {@link MappedByteBuffer}
- - */
- - public static ImageSegmenter createFromBufferAndOptions(
- - final ByteBuffer modelBuffer, final ImageSegmenterOptions options) {
- - if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- - throw new IllegalArgumentException(
- - "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
- +
- + /**
- + * Performs actual segmentation on the provided {@code MlImage}.
- + *
- + * @param image an {@code MlImage} to segment.
- + * @return results of performing image segmentation. Note that at the time, a single {@link
- + * Segmentation} element is expected to be returned. The result is stored in a {@link List}
- + * for later extension to e.g. instance segmentation models, which may return one
- + * segmentation per object.
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- + */
- + public List<Segmentation> segment(MlImage image) {
- + return segment(image, ImageProcessingOptions.builder().build());
- }
- - return new ImageSegmenter(
- - TaskJniUtils.createHandleFromLibrary(
- - new EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - return initJniWithByteBuffer(
- - modelBuffer,
- - options.getDisplayNamesLocale(),
- - options.getOutputType().getValue(),
- - TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- - options.getBaseOptions(), options.getNumThreads()));
- - }
- - },
- - IMAGE_SEGMENTER_NATIVE_LIB),
- - options.getOutputType());
- - }
- -
- - /**
- - * Constructor to initialize the JNI with a pointer from C++.
- - *
- - * @param nativeHandle a pointer referencing memory allocated in C++
- - */
- - private ImageSegmenter(long nativeHandle, OutputType outputType) {
- - super(nativeHandle);
- - this.outputType = outputType;
- - }
- -
- - /** Options for setting up an {@link ImageSegmenter}. */
- - @AutoValue
- - public abstract static class ImageSegmenterOptions {
- - private static final String DEFAULT_DISPLAY_NAME_LOCALE = "en";
- - private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.CATEGORY_MASK;
- - private static final int NUM_THREADS = -1;
- -
- - public abstract BaseOptions getBaseOptions();
- -
- - public abstract String getDisplayNamesLocale();
- -
- - public abstract OutputType getOutputType();
- -
- - public abstract int getNumThreads();
- -
- - public static Builder builder() {
- - return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
- - .setDisplayNamesLocale(DEFAULT_DISPLAY_NAME_LOCALE)
- - .setOutputType(DEFAULT_OUTPUT_TYPE)
- - .setNumThreads(NUM_THREADS)
- - .setBaseOptions(BaseOptions.builder().build());
- +
- + /**
- + * Performs actual segmentation on the provided {@code MlImage} with {@link
- + * ImageProcessingOptions}.
- + *
- + * <p>{@link ImageSegmenter} supports the following options:
- + *
- + * <ul>
- + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
- + * MlImage#getRotation()} is not effective.
- + * </ul>
- + *
- + * @param image an {@code MlImage} to segment.
- + * @param options the options configure how to preprocess the image.
- + * @return results of performing image segmentation. Note that at the time, a single {@link
- + * Segmentation} element is expected to be returned. The result is stored in a {@link List}
- + * for later extension to e.g. instance segmentation models, which may return one
- + * segmentation per object.
- + * @throws IllegalStateException if there is an internal error
- + * @throws RuntimeException if there is an otherwise unspecified error
- + * @throws IllegalArgumentException if the color space type of image is unsupported
- + */
- + public List<Segmentation> segment(MlImage image, ImageProcessingOptions options) {
- + image.getInternal().acquire();
- + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- + List<Segmentation> result = segment(tensorImage, options);
- + image.close();
- + return result;
- }
-
- - /** Builder for {@link ImageSegmenterOptions}. */
- - @AutoValue.Builder
- - public abstract static class Builder {
- -
- - /** Sets the general options to configure Task APIs, such as accelerators. */
- - public abstract Builder setBaseOptions(BaseOptions baseOptions);
- -
- - /**
- - * Sets the locale to use for display names specified through the TFLite Model Metadata, if
- - * any.
- - *
- - * <p>Defaults to English({@code "en"}). See the <a
- - * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
- - * Metadata schema file.</a> for the accepted pattern of locale.
- - */
- - public abstract Builder setDisplayNamesLocale(String displayNamesLocale);
- -
- - public abstract Builder setOutputType(OutputType outputType);
- -
- - /**
- - * Sets the number of threads to be used for TFLite ops that support multi-threading when
- - * running inference with CPU. Defaults to -1.
- - *
- - * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
- - * effect to let TFLite runtime set the value.
- - *
- - * @deprecated use {@link BaseOptions} to configure number of threads instead. This method
- - * will override the number of threads configured from {@link BaseOptions}.
- - */
- - @Deprecated
- - public abstract Builder setNumThreads(int numThreads);
- -
- - public abstract ImageSegmenterOptions build();
- + public List<Segmentation> segment(long frameBufferHandle, ImageProcessingOptions options) {
- + checkNotClosed();
- +
- + List<byte[]> maskByteArrays = new ArrayList<>();
- + List<ColoredLabel> coloredLabels = new ArrayList<>();
- + int[] maskShape = new int[2];
- + segmentNative(
- + getNativeHandle(), frameBufferHandle, maskByteArrays, maskShape, coloredLabels);
- +
- + List<ByteBuffer> maskByteBuffers = new ArrayList<>();
- + for (byte[] bytes : maskByteArrays) {
- + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
- + // Change the byte order to little_endian, since the buffers were generated in jni.
- + byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
- + maskByteBuffers.add(byteBuffer);
- + }
- +
- + return Arrays.asList(Segmentation.create(outputType,
- + outputType.createMasksFromBuffer(maskByteBuffers, maskShape), coloredLabels));
- }
- - }
- -
- - /**
- - * Performs actual segmentation on the provided image.
- - *
- - * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types:
- - *
- - * <ul>
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- - * </ul>
- - *
- - * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- - * @return results of performing image segmentation. Note that at the time, a single {@link
- - * Segmentation} element is expected to be returned. The result is stored in a {@link List}
- - * for later extension to e.g. instance segmentation models, which may return one segmentation
- - * per object.
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - * @throws IllegalArgumentException if the color space type of image is unsupported
- - */
- - public List<Segmentation> segment(TensorImage image) {
- - return segment(image, ImageProcessingOptions.builder().build());
- - }
- -
- - /**
- - * Performs actual segmentation on the provided image with {@link ImageProcessingOptions}.
- - *
- - * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types:
- - *
- - * <ul>
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- - * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- - * </ul>
- - *
- - * <p>{@link ImageSegmenter} supports the following options:
- - *
- - * <ul>
- - * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- - * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}
- - * </ul>
- - *
- - * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- - * @param options the options configure how to preprocess the image
- - * @return results of performing image segmentation. Note that at the time, a single {@link
- - * Segmentation} element is expected to be returned. The result is stored in a {@link List}
- - * for later extension to e.g. instance segmentation models, which may return one segmentation
- - * per object.
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - * @throws IllegalArgumentException if the color space type of image is unsupported
- - */
- - public List<Segmentation> segment(TensorImage image, ImageProcessingOptions options) {
- - return run(
- - new InferenceProvider<List<Segmentation>>() {
- - @Override
- - public List<Segmentation> run(
- - long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
- - return segment(frameBufferHandle, options);
- - }
- - },
- - image,
- - options);
- - }
- -
- - /**
- - * Performs actual segmentation on the provided {@code MlImage}.
- - *
- - * @param image an {@code MlImage} to segment.
- - * @return results of performing image segmentation. Note that at the time, a single {@link
- - * Segmentation} element is expected to be returned. The result is stored in a {@link List}
- - * for later extension to e.g. instance segmentation models, which may return one segmentation
- - * per object.
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- - */
- - public List<Segmentation> segment(MlImage image) {
- - return segment(image, ImageProcessingOptions.builder().build());
- - }
- -
- - /**
- - * Performs actual segmentation on the provided {@code MlImage} with {@link
- - * ImageProcessingOptions}.
- - *
- - * <p>{@link ImageSegmenter} supports the following options:
- - *
- - * <ul>
- - * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- - * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
- - * MlImage#getRotation()} is not effective.
- - * </ul>
- - *
- - * @param image an {@code MlImage} to segment.
- - * @param options the options configure how to preprocess the image.
- - * @return results of performing image segmentation. Note that at the time, a single {@link
- - * Segmentation} element is expected to be returned. The result is stored in a {@link List}
- - * for later extension to e.g. instance segmentation models, which may return one segmentation
- - * per object.
- - * @throws IllegalStateException if there is an internal error
- - * @throws RuntimeException if there is an otherwise unspecified error
- - * @throws IllegalArgumentException if the color space type of image is unsupported
- - */
- - public List<Segmentation> segment(MlImage image, ImageProcessingOptions options) {
- - image.getInternal().acquire();
- - TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- - List<Segmentation> result = segment(tensorImage, options);
- - image.close();
- - return result;
- - }
- -
- - public List<Segmentation> segment(long frameBufferHandle, ImageProcessingOptions options) {
- - checkNotClosed();
- -
- - List<byte[]> maskByteArrays = new ArrayList<>();
- - List<ColoredLabel> coloredLabels = new ArrayList<>();
- - int[] maskShape = new int[2];
- - segmentNative(getNativeHandle(), frameBufferHandle, maskByteArrays, maskShape, coloredLabels);
- -
- - List<ByteBuffer> maskByteBuffers = new ArrayList<>();
- - for (byte[] bytes : maskByteArrays) {
- - ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
- - // Change the byte order to little_endian, since the buffers were generated in jni.
- - byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
- - maskByteBuffers.add(byteBuffer);
- +
- + private static ImageSegmenter createFromModelFdAndOptions(final int fileDescriptor,
- + final long fileDescriptorLength, final long fileDescriptorOffset,
- + final ImageSegmenterOptions options) {
- + long nativeHandle = TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
- + @Override
- + public long createHandle() {
- + return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength,
- + fileDescriptorOffset, options.getDisplayNamesLocale(),
- + options.getOutputType().getValue(),
- + TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- + options.getBaseOptions(), options.getNumThreads()));
- + }
- + }, IMAGE_SEGMENTER_NATIVE_LIB);
- + return new ImageSegmenter(nativeHandle, options.getOutputType());
- + }
- +
- + private static native long initJniWithModelFdAndOptions(int fileDescriptor,
- + long fileDescriptorLength, long fileDescriptorOffset, String displayNamesLocale,
- + int outputType, long baseOptionsHandle);
- +
- + private static native long initJniWithByteBuffer(ByteBuffer modelBuffer,
- + String displayNamesLocale, int outputType, long baseOptionsHandle);
- +
- + /**
- + * The native method to segment the image.
- + *
- + * <p>{@code maskBuffers}, {@code maskShape}, {@code coloredLabels} will be updated in the
- + * native layer.
- + */
- + private static native void segmentNative(long nativeHandle, long frameBufferHandle,
- + List<byte[]> maskByteArrays, int[] maskShape, List<ColoredLabel> coloredLabels);
- +
- + @Override
- + protected void deinit(long nativeHandle) {
- + deinitJni(nativeHandle);
- }
-
- - return Arrays.asList(
- - Segmentation.create(
- - outputType,
- - outputType.createMasksFromBuffer(maskByteBuffers, maskShape),
- - coloredLabels));
- - }
- -
- - private static ImageSegmenter createFromModelFdAndOptions(
- - final int fileDescriptor,
- - final long fileDescriptorLength,
- - final long fileDescriptorOffset,
- - final ImageSegmenterOptions options) {
- - long nativeHandle =
- - TaskJniUtils.createHandleFromLibrary(
- - new EmptyHandleProvider() {
- - @Override
- - public long createHandle() {
- - return initJniWithModelFdAndOptions(
- - fileDescriptor,
- - fileDescriptorLength,
- - fileDescriptorOffset,
- - options.getDisplayNamesLocale(),
- - options.getOutputType().getValue(),
- - TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- - options.getBaseOptions(), options.getNumThreads()));
- - }
- - },
- - IMAGE_SEGMENTER_NATIVE_LIB);
- - return new ImageSegmenter(nativeHandle, options.getOutputType());
- - }
- -
- - private static native long initJniWithModelFdAndOptions(
- - int fileDescriptor,
- - long fileDescriptorLength,
- - long fileDescriptorOffset,
- - String displayNamesLocale,
- - int outputType,
- - long baseOptionsHandle);
- -
- - private static native long initJniWithByteBuffer(
- - ByteBuffer modelBuffer, String displayNamesLocale, int outputType, long baseOptionsHandle);
- -
- - /**
- - * The native method to segment the image.
- - *
- - * <p>{@code maskBuffers}, {@code maskShape}, {@code coloredLabels} will be updated in the native
- - * layer.
- - */
- - private static native void segmentNative(
- - long nativeHandle,
- - long frameBufferHandle,
- - List<byte[]> maskByteArrays,
- - int[] maskShape,
- - List<ColoredLabel> coloredLabels);
- -
- - @Override
- - protected void deinit(long nativeHandle) {
- - deinitJni(nativeHandle);
- - }
- -
- - /**
- - * Native implementation to release memory pointed by the pointer.
- - *
- - * @param nativeHandle pointer to memory allocated
- - */
- - private native void deinitJni(long nativeHandle);
- + /**
- + * Native implementation to release memory pointed by the pointer.
- + *
- + * @param nativeHandle pointer to memory allocated
- + */
- + private native void deinitJni(long nativeHandle);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java
- index 26ace1eaa1783..8c69cf5d152a0 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java
- @@ -20,126 +20,128 @@ import static org.tensorflow.lite.DataType.UINT8;
- import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument;
- import static org.tensorflow.lite.support.image.ColorSpaceType.GRAYSCALE;
-
- +import org.tensorflow.lite.support.image.TensorImage;
- +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- +
- import java.nio.ByteBuffer;
- import java.util.ArrayList;
- import java.util.List;
- -import org.tensorflow.lite.support.image.TensorImage;
- -import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- /**
- * Output mask type. This allows specifying the type of post-processing to perform on the raw model
- * results.
- */
- public enum OutputType {
- -
- - /**
- - * Gives a single output mask where each pixel represents the class which the pixel in the
- - * original image was predicted to belong to.
- - */
- - CATEGORY_MASK(0) {
- /**
- - * {@inheritDoc}
- - *
- - * @throws IllegalArgumentException if more than one {@link TensorImage} are provided, or if the
- - * color space of the {@link TensorImage} is not {@link ColorSpaceType#GRAYSCALE}
- + * Gives a single output mask where each pixel represents the class which the pixel in the
- + * original image was predicted to belong to.
- */
- - @Override
- - void assertMasksMatchColoredLabels(List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
- - checkArgument(
- - masks.size() == 1,
- - "CATRGORY_MASK only allows one TensorImage in the list, providing " + masks.size());
- -
- - TensorImage mask = masks.get(0);
- - checkArgument(
- - mask.getColorSpaceType() == GRAYSCALE,
- - "CATRGORY_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing "
- - + mask.getColorSpaceType());
- - }
- + CATEGORY_MASK(0) {
- + /**
- + * {@inheritDoc}
- + *
- + * @throws IllegalArgumentException if more than one {@link TensorImage} are provided, or if
- + * the
- + * color space of the {@link TensorImage} is not {@link ColorSpaceType#GRAYSCALE}
- + */
- + @Override
- + void assertMasksMatchColoredLabels(
- + List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
- + checkArgument(masks.size() == 1,
- + "CATRGORY_MASK only allows one TensorImage in the list, providing "
- + + masks.size());
- +
- + TensorImage mask = masks.get(0);
- + checkArgument(mask.getColorSpaceType() == GRAYSCALE,
- + "CATRGORY_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing "
- + + mask.getColorSpaceType());
- + }
- +
- + /**
- + * {@inheritDoc}
- + *
- + * @throws IllegalArgumentException if more than one {@link ByteBuffer} are provided in the
- + * list
- + */
- + @Override
- + List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
- + checkArgument(buffers.size() == 1,
- + "CATRGORY_MASK only allows one mask in the buffer list, providing "
- + + buffers.size());
- +
- + List<TensorImage> masks = new ArrayList<>();
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(UINT8);
- + tensorBuffer.loadBuffer(buffers.get(0), maskShape);
- + TensorImage tensorImage = new TensorImage(UINT8);
- + tensorImage.load(tensorBuffer, GRAYSCALE);
- + masks.add(tensorImage);
- +
- + return masks;
- + }
- + },
-
- /**
- - * {@inheritDoc}
- - *
- - * @throws IllegalArgumentException if more than one {@link ByteBuffer} are provided in the list
- + * Gives a list of output masks where, for each mask, each pixel represents the prediction
- + * confidence, usually in the [0, 1] range.
- */
- - @Override
- - List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
- - checkArgument(
- - buffers.size() == 1,
- - "CATRGORY_MASK only allows one mask in the buffer list, providing " + buffers.size());
- -
- - List<TensorImage> masks = new ArrayList<>();
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(UINT8);
- - tensorBuffer.loadBuffer(buffers.get(0), maskShape);
- - TensorImage tensorImage = new TensorImage(UINT8);
- - tensorImage.load(tensorBuffer, GRAYSCALE);
- - masks.add(tensorImage);
- -
- - return masks;
- + CONFIDENCE_MASK(1) {
- + /**
- + * {@inheritDoc}
- + *
- + * @throws IllegalArgumentException if more the size of the masks list does not match the
- + * size
- + * of the coloredlabels list, or if the color space type of the any mask is not {@link
- + * ColorSpaceType#GRAYSCALE}
- + */
- + @Override
- + void assertMasksMatchColoredLabels(
- + List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
- + checkArgument(masks.size() == coloredLabels.size(),
- + String.format(
- + "When using CONFIDENCE_MASK, the number of masks (%d) should match the number of"
- + + " coloredLabels (%d).",
- + masks.size(), coloredLabels.size()));
- +
- + for (TensorImage mask : masks) {
- + checkArgument(mask.getColorSpaceType() == GRAYSCALE,
- + "CONFIDENCE_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing "
- + + mask.getColorSpaceType());
- + }
- + }
- +
- + @Override
- + List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
- + List<TensorImage> masks = new ArrayList<>();
- + for (ByteBuffer buffer : buffers) {
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(FLOAT32);
- + tensorBuffer.loadBuffer(buffer, maskShape);
- + TensorImage tensorImage = new TensorImage(FLOAT32);
- + tensorImage.load(tensorBuffer, GRAYSCALE);
- + masks.add(tensorImage);
- + }
- + return masks;
- + }
- + };
- +
- + public int getValue() {
- + return value;
- }
- - },
-
- - /**
- - * Gives a list of output masks where, for each mask, each pixel represents the prediction
- - * confidence, usually in the [0, 1] range.
- - */
- - CONFIDENCE_MASK(1) {
- /**
- - * {@inheritDoc}
- + * Verifies that the given list of masks matches the list of colored labels.
- *
- - * @throws IllegalArgumentException if more the size of the masks list does not match the size
- - * of the coloredlabels list, or if the color space type of the any mask is not {@link
- - * ColorSpaceType#GRAYSCALE}
- + * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
- + * output type
- */
- - @Override
- - void assertMasksMatchColoredLabels(List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
- - checkArgument(
- - masks.size() == coloredLabels.size(),
- - String.format(
- - "When using CONFIDENCE_MASK, the number of masks (%d) should match the number of"
- - + " coloredLabels (%d).",
- - masks.size(), coloredLabels.size()));
- -
- - for (TensorImage mask : masks) {
- - checkArgument(
- - mask.getColorSpaceType() == GRAYSCALE,
- - "CONFIDENCE_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing "
- - + mask.getColorSpaceType());
- - }
- - }
- -
- - @Override
- - List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
- - List<TensorImage> masks = new ArrayList<>();
- - for (ByteBuffer buffer : buffers) {
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(FLOAT32);
- - tensorBuffer.loadBuffer(buffer, maskShape);
- - TensorImage tensorImage = new TensorImage(FLOAT32);
- - tensorImage.load(tensorBuffer, GRAYSCALE);
- - masks.add(tensorImage);
- - }
- - return masks;
- - }
- - };
- + abstract void assertMasksMatchColoredLabels(
- + List<TensorImage> masks, List<ColoredLabel> coloredLabels);
-
- - public int getValue() {
- - return value;
- - }
- + /** Creates the masks in {@link TensorImage} based on the data in {@link ByteBuffer}. */
- + abstract List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape);
-
- - /**
- - * Verifies that the given list of masks matches the list of colored labels.
- - *
- - * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
- - * output type
- - */
- - abstract void assertMasksMatchColoredLabels(
- - List<TensorImage> masks, List<ColoredLabel> coloredLabels);
- + private final int value;
-
- - /** Creates the masks in {@link TensorImage} based on the data in {@link ByteBuffer}. */
- - abstract List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape);
- -
- - private final int value;
- -
- - private OutputType(int value) {
- - this.value = value;
- - }
- + private OutputType(int value) {
- + this.value = value;
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java
- index 018482c7e82db..f5062bc8745f0 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java
- @@ -16,67 +16,69 @@ limitations under the License.
- package org.tensorflow.lite.task.vision.segmenter;
-
- import com.google.auto.value.AutoValue;
- +
- +import org.tensorflow.lite.support.image.TensorImage;
- +
- import java.util.ArrayList;
- import java.util.Collections;
- import java.util.List;
- -import org.tensorflow.lite.support.image.TensorImage;
-
- /** Represents the segmentation result of an {@link ImageSegmenter}. */
- @AutoValue
- public abstract class Segmentation {
- + /**
- + * Creates a {@link Segmentation} object.
- + *
- + * <p>{@link Segmentation} provides two types of outputs as indicated through {@link
- + * OutputType}:
- + *
- + * <p>{@link OutputType#CATEGORY_MASK}: the result contains a single category mask, which is a
- + * grayscale {@link TensorImage} with shape (height, width), in row major order. The value of
- + * each pixel in this mask represents the class to which the pixel in the mask belongs. The
- + * pixel values are in 1:1 corresponding with the colored labels, i.e. a pixel with value {@code
- + * i} is associated with {@code coloredLabels.get(i)}.
- + *
- + * <p>{@link OutputType#CONFIDENCE_MASK}: the result contains a list of confidence masks, which
- + * are in 1:1 correspondance with the colored labels, i.e. {@link masks.get(i)} is associated
- + * with
- + * {@code coloredLabels.get(i)}. Each confidence mask is a grayscale {@link TensorImage} with
- + * shape (height, width), in row major order. The value of each pixel in these masks represents
- + * the confidence score for this particular class.
- + *
- + * <p>IMPORTANT: segmentation masks are not direcly suited for display, in particular:<br>
- + * \* they are relative to the unrotated input frame, i.e. *not* taking into account the {@code
- + * Orientation} flag of the input FrameBuffer, <br>
- + * \* their dimensions are intrinsic to the model, i.e. *not* dependent on the input FrameBuffer
- + * dimensions.
- + *
- + * <p>Example of such post-processing, assuming: <br>
- + * \* an input FrameBuffer with width=640, height=480, orientation=kLeftBottom (i.e. the image
- + * will be rotated 90° clockwise during preprocessing to make it "upright"), <br>
- + * \* a model outputting masks of size 224x224. <br>
- + * In order to be directly displayable on top of the input image assumed to be displayed *with*
- + * the {@code Orientation} flag taken into account (according to the <a
- + * href="http://jpegclub.org/exif_orientation.html">EXIF specification</a>), the masks need to
- + * be: re-scaled to 640 x 480, then rotated 90° clockwise.
- + *
- + * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
- + * {@code outputType}
- + */
- + static Segmentation create(
- + OutputType outputType, List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
- + outputType.assertMasksMatchColoredLabels(masks, coloredLabels);
-
- - /**
- - * Creates a {@link Segmentation} object.
- - *
- - * <p>{@link Segmentation} provides two types of outputs as indicated through {@link OutputType}:
- - *
- - * <p>{@link OutputType#CATEGORY_MASK}: the result contains a single category mask, which is a
- - * grayscale {@link TensorImage} with shape (height, width), in row major order. The value of each
- - * pixel in this mask represents the class to which the pixel in the mask belongs. The pixel
- - * values are in 1:1 corresponding with the colored labels, i.e. a pixel with value {@code i} is
- - * associated with {@code coloredLabels.get(i)}.
- - *
- - * <p>{@link OutputType#CONFIDENCE_MASK}: the result contains a list of confidence masks, which
- - * are in 1:1 correspondance with the colored labels, i.e. {@link masks.get(i)} is associated with
- - * {@code coloredLabels.get(i)}. Each confidence mask is a grayscale {@link TensorImage} with
- - * shape (height, width), in row major order. The value of each pixel in these masks represents
- - * the confidence score for this particular class.
- - *
- - * <p>IMPORTANT: segmentation masks are not direcly suited for display, in particular:<br>
- - * \* they are relative to the unrotated input frame, i.e. *not* taking into account the {@code
- - * Orientation} flag of the input FrameBuffer, <br>
- - * \* their dimensions are intrinsic to the model, i.e. *not* dependent on the input FrameBuffer
- - * dimensions.
- - *
- - * <p>Example of such post-processing, assuming: <br>
- - * \* an input FrameBuffer with width=640, height=480, orientation=kLeftBottom (i.e. the image
- - * will be rotated 90° clockwise during preprocessing to make it "upright"), <br>
- - * \* a model outputting masks of size 224x224. <br>
- - * In order to be directly displayable on top of the input image assumed to be displayed *with*
- - * the {@code Orientation} flag taken into account (according to the <a
- - * href="http://jpegclub.org/exif_orientation.html">EXIF specification</a>), the masks need to be:
- - * re-scaled to 640 x 480, then rotated 90° clockwise.
- - *
- - * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
- - * {@code outputType}
- - */
- - static Segmentation create(
- - OutputType outputType, List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
- - outputType.assertMasksMatchColoredLabels(masks, coloredLabels);
- -
- - return new AutoValue_Segmentation(
- - outputType,
- - Collections.unmodifiableList(new ArrayList<TensorImage>(masks)),
- - Collections.unmodifiableList(new ArrayList<ColoredLabel>(coloredLabels)));
- - }
- + return new AutoValue_Segmentation(outputType,
- + Collections.unmodifiableList(new ArrayList<TensorImage>(masks)),
- + Collections.unmodifiableList(new ArrayList<ColoredLabel>(coloredLabels)));
- + }
-
- - public abstract OutputType getOutputType();
- + public abstract OutputType getOutputType();
-
- - // As an open source project, we've been trying avoiding depending on common java libraries,
- - // such as Guava, because it may introduce conflicts with clients who also happen to use those
- - // libraries. Therefore, instead of using ImmutableList here, we convert the List into
- - // unmodifiableList in create() to make it less vulnerable.
- - public abstract List<TensorImage> getMasks();
- + // As an open source project, we've been trying avoiding depending on common java libraries,
- + // such as Guava, because it may introduce conflicts with clients who also happen to use those
- + // libraries. Therefore, instead of using ImmutableList here, we convert the List into
- + // unmodifiableList in create() to make it less vulnerable.
- + public abstract List<TensorImage> getMasks();
-
- - public abstract List<ColoredLabel> getColoredLabels();
- + public abstract List<ColoredLabel> getColoredLabels();
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java
- index f53cfd7a9510a..02aa581c3559c 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java
- @@ -16,6 +16,7 @@ limitations under the License.
- package org.tensorflow.lite.support.audio;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertThrows;
- import static org.mockito.ArgumentMatchers.any;
- import static org.mockito.ArgumentMatchers.anyInt;
- @@ -25,6 +26,7 @@ import static org.mockito.Mockito.when;
-
- import android.media.AudioFormat;
- import android.media.AudioRecord;
- +
- import org.junit.Test;
- import org.junit.runner.RunWith;
- import org.junit.runners.Suite;
- @@ -35,259 +37,258 @@ import org.tensorflow.lite.support.audio.TensorAudio.TensorAudioFormat;
- /** Test for {@link TensorAudio}. */
- @RunWith(Suite.class)
- @SuiteClasses({
- - TensorAudioTest.General.class,
- + TensorAudioTest.General.class,
- })
- public class TensorAudioTest {
- -
- - /** General tests of TensorAudio. */
- - @RunWith(RobolectricTestRunner.class)
- - public static final class General extends TensorAudioTest {
- - @Test
- - public void createSucceedsWithTensorAudioFormat() throws Exception {
- - TensorAudio tensor =
- - TensorAudio.create(
- - TensorAudioFormat.builder().setChannels(1).setSampleRate(2).build(), 100);
- - assertThat(tensor.getFormat().getChannels()).isEqualTo(1);
- - assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2);
- - assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(100);
- - }
- -
- - @Test
- - public void createSucceedsWithTensorAudioFormatWithMultipleChannels() throws Exception {
- - TensorAudio tensor =
- - TensorAudio.create(
- - TensorAudioFormat.builder().setChannels(5).setSampleRate(2).build(), 100);
- - assertThat(tensor.getFormat().getChannels()).isEqualTo(5);
- - assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2);
- - assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(500);
- - }
- -
- - @Test
- - public void createSucceededsWithDefaultArguments() throws Exception {
- - TensorAudio tensor =
- - TensorAudio.create(TensorAudioFormat.builder().setSampleRate(20).build(), 1000);
- - // Number of channels defaults to 1.
- - assertThat(tensor.getFormat().getChannels()).isEqualTo(1);
- - assertThat(tensor.getFormat().getSampleRate()).isEqualTo(20);
- - assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(1000);
- - }
- -
- - @Test
- - public void createSucceedsWithAudioFormat() throws Exception {
- - AudioFormat format =
- - new AudioFormat.Builder()
- - .setChannelMask(AudioFormat.CHANNEL_IN_STEREO)
- - .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
- - .setSampleRate(16000)
- - .build();
- - TensorAudio tensor = TensorAudio.create(format, 100);
- - // STEREO has 2 channels
- - assertThat(tensor.getFormat().getChannels()).isEqualTo(2);
- - assertThat(tensor.getFormat().getSampleRate()).isEqualTo(16000);
- - // flatSize = channelCount * sampleCount
- - assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(200);
- - }
- -
- - @Test
- - public void createFailedWithInvalidSampleRate() throws Exception {
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class,
- - () -> TensorAudio.create(TensorAudioFormat.builder().setSampleRate(0).build(), 100));
- - // Sample rate 0 is not allowed
- - assertThat(exception).hasMessageThat().ignoringCase().contains("sample rate");
- - }
- -
- - @Test
- - public void createFailedWithInvalidChannels() throws Exception {
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class,
- - () ->
- - TensorAudio.create(
- - TensorAudioFormat.builder().setSampleRate(1).setChannels(-1).build(), 100));
- - // Negative channels is not allowed
- - assertThat(exception).hasMessageThat().ignoringCase().contains("channels");
- - }
- -
- - @Test
- - public void loadSucceedsFromArray() throws Exception {
- - TensorAudioFormat format =
- - TensorAudioFormat.builder().setChannels(2).setSampleRate(2).build();
- - TensorAudio tensor = TensorAudio.create(format, 2);
- - assertThat(tensor.getTensorBuffer().getFloatArray()).isEqualTo(new float[4]);
- -
- - tensor.load(new float[] {2.f, 0});
- - assertThat(tensor.getTensorBuffer().getFloatArray())
- - .usingTolerance(0.001f)
- - .containsExactly(new float[] {0, 0, 2.f, 0});
- -
- - tensor.load(new float[] {2.f, 3.f}, 0, 2);
- - assertThat(tensor.getTensorBuffer().getFloatArray())
- - .usingTolerance(0.001f)
- - .containsExactly(new float[] {2.f, 0, 2.f, 3.f});
- -
- - tensor.load(new float[] {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}, 1, 6);
- - // The sequence is longer than the ring buffer size so it's expected to keep only the last 4
- - // numbers (index 3 to 6) of the load target sub-sequence (index 1 to 6).
- - assertThat(tensor.getTensorBuffer().getFloatArray())
- - .usingTolerance(0.001f)
- - .containsExactly(new float[] {5.f, 6.f, 7.f, 8.f});
- -
- - tensor.load(new short[] {Short.MAX_VALUE, Short.MIN_VALUE});
- - assertThat(tensor.getTensorBuffer().getFloatArray())
- - .usingTolerance(0.001f)
- - .containsExactly(new float[] {7.f, 8.f, 1.f, -1.f});
- -
- - tensor.load(new short[] {1000, 2000, 3000, 0, 1000, Short.MIN_VALUE, 4000, 5000, 6000}, 3, 6);
- - // The sequence is longer than the ring buffer size so it's expected to keep only the last 4
- - // numbers.
- - assertThat(tensor.getTensorBuffer().getFloatArray())
- - .usingTolerance(0.001f)
- - .containsExactly(
- - new float[] {
- - -1.f, 4000.f / Short.MAX_VALUE, 5000.f / Short.MAX_VALUE, 6000.f / Short.MAX_VALUE
- - });
- - }
- -
- - @Test
- - public void loadFailsWithIndexOutOfRange() throws Exception {
- - TensorAudioFormat format = TensorAudioFormat.builder().setSampleRate(2).build();
- - TensorAudio tensor = TensorAudio.create(format, 5);
- -
- - assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[100], 99, 2));
- -
- - assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[100], 99, 2));
- - }
- -
- - @Test
- - public void loadFailsWithIncompatibleInputSize() throws Exception {
- - TensorAudioFormat format =
- - TensorAudioFormat.builder().setChannels(3).setSampleRate(2).build();
- - TensorAudio tensor = TensorAudio.create(format, 5);
- -
- - assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[1]));
- -
- - assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[2]));
- -
- - assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[2], 1, 1));
- -
- - assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[5], 2, 4));
- - }
- -
- - @Test
- - public void loadAudioRecordSucceeds() throws Exception {
- - TensorAudio tensor =
- - TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
- - tensor.load(new float[] {1, 2, 3, 4, 5});
- - assertThat(tensor.getTensorBuffer().getFloatArray())
- - .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f});
- -
- - AudioRecord record = mock(AudioRecord.class);
- - when(record.getBufferSizeInFrames()).thenReturn(5);
- - when(record.getChannelCount()).thenReturn(1);
- - when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT);
- - when(record.getFormat())
- - .thenReturn(
- - new AudioFormat.Builder()
- - .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
- - .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
- - .setSampleRate(16000)
- - .build());
- - // Unused
- - when(record.read(any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- - .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
- - // Used
- - when(record.read(any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- - .thenReturn(1);
- - assertThat(tensor.load(record)).isEqualTo(1);
- - assertThat(tensor.getTensorBuffer().getFloatArray())
- - .isEqualTo(new float[] {3.f, 4.f, 5.f, 0});
- -
- - record = mock(AudioRecord.class);
- - when(record.getBufferSizeInFrames()).thenReturn(5);
- - when(record.getChannelCount()).thenReturn(1);
- - when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_16BIT);
- - when(record.getFormat())
- - .thenReturn(
- - new AudioFormat.Builder()
- - .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
- - .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
- - .setSampleRate(16000)
- - .build());
- - // Used
- - when(record.read(any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- - .thenReturn(2);
- - // Unused
- - when(record.read(any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- - .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
- - assertThat(tensor.load(record)).isEqualTo(2);
- - assertThat(tensor.getTensorBuffer().getFloatArray()).isEqualTo(new float[] {5.f, 0, 0, 0});
- - }
- -
- - @Test
- - public void loadAudioRecordFailsWithErrorState() throws Exception {
- - TensorAudio tensor =
- - TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
- - tensor.load(new float[] {1, 2, 3, 4, 5});
- - assertThat(tensor.getTensorBuffer().getFloatArray())
- - .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f});
- -
- - AudioRecord record = mock(AudioRecord.class);
- - when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT);
- - when(record.getFormat())
- - .thenReturn(
- - new AudioFormat.Builder()
- - .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
- - .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
- - .setSampleRate(16000)
- - .build());
- - // Unused
- - when(record.read(any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- - .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
- - // Used
- - when(record.read(any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- - .thenReturn(AudioRecord.ERROR_DEAD_OBJECT);
- - IllegalStateException exception =
- - assertThrows(IllegalStateException.class, () -> tensor.load(record));
- - assertThat(exception).hasMessageThat().contains("ERROR_DEAD_OBJECT");
- - }
- -
- - @Test
- - public void loadAudioRecordFailsWithUnsupportedAudioEncoding() throws Exception {
- - TensorAudio tensor =
- - TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
- - AudioRecord record = mock(AudioRecord.class);
- - when(record.getFormat())
- - .thenReturn(
- - new AudioFormat.Builder()
- - .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
- - .setEncoding(AudioFormat.ENCODING_PCM_8BIT) // Not supported
- - .setSampleRate(16000)
- - .build());
- - when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_8BIT);
- -
- - IllegalArgumentException exception =
- - assertThrows(IllegalArgumentException.class, () -> tensor.load(record));
- - assertThat(exception).hasMessageThat().ignoringCase().contains("unsupported encoding");
- - }
- -
- - @Test
- - public void loadAudioRecordFailsWithIncompatibleAudioFormat() throws Exception {
- - TensorAudio tensor =
- - TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
- - AudioRecord record = mock(AudioRecord.class);
- - when(record.getFormat())
- - .thenReturn(
- - new AudioFormat.Builder()
- - .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
- - .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
- - .setSampleRate(44100) // Mismatch
- - .build());
- -
- - IllegalArgumentException exception =
- - assertThrows(IllegalArgumentException.class, () -> tensor.load(record));
- - assertThat(exception).hasMessageThat().ignoringCase().contains("Incompatible audio format");
- + /** General tests of TensorAudio. */
- + @RunWith(RobolectricTestRunner.class)
- + public static final class General extends TensorAudioTest {
- + @Test
- + public void createSucceedsWithTensorAudioFormat() throws Exception {
- + TensorAudio tensor = TensorAudio.create(
- + TensorAudioFormat.builder().setChannels(1).setSampleRate(2).build(), 100);
- + assertThat(tensor.getFormat().getChannels()).isEqualTo(1);
- + assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2);
- + assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(100);
- + }
- +
- + @Test
- + public void createSucceedsWithTensorAudioFormatWithMultipleChannels() throws Exception {
- + TensorAudio tensor = TensorAudio.create(
- + TensorAudioFormat.builder().setChannels(5).setSampleRate(2).build(), 100);
- + assertThat(tensor.getFormat().getChannels()).isEqualTo(5);
- + assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2);
- + assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(500);
- + }
- +
- + @Test
- + public void createSucceededsWithDefaultArguments() throws Exception {
- + TensorAudio tensor =
- + TensorAudio.create(TensorAudioFormat.builder().setSampleRate(20).build(), 1000);
- + // Number of channels defaults to 1.
- + assertThat(tensor.getFormat().getChannels()).isEqualTo(1);
- + assertThat(tensor.getFormat().getSampleRate()).isEqualTo(20);
- + assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(1000);
- + }
- +
- + @Test
- + public void createSucceedsWithAudioFormat() throws Exception {
- + AudioFormat format = new AudioFormat.Builder()
- + .setChannelMask(AudioFormat.CHANNEL_IN_STEREO)
- + .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
- + .setSampleRate(16000)
- + .build();
- + TensorAudio tensor = TensorAudio.create(format, 100);
- + // STEREO has 2 channels
- + assertThat(tensor.getFormat().getChannels()).isEqualTo(2);
- + assertThat(tensor.getFormat().getSampleRate()).isEqualTo(16000);
- + // flatSize = channelCount * sampleCount
- + assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(200);
- + }
- +
- + @Test
- + public void createFailedWithInvalidSampleRate() throws Exception {
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + ()
- + -> TensorAudio.create(
- + TensorAudioFormat.builder().setSampleRate(0).build(), 100));
- + // Sample rate 0 is not allowed
- + assertThat(exception).hasMessageThat().ignoringCase().contains("sample rate");
- + }
- +
- + @Test
- + public void createFailedWithInvalidChannels() throws Exception {
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + ()
- + -> TensorAudio.create(TensorAudioFormat.builder()
- + .setSampleRate(1)
- + .setChannels(-1)
- + .build(),
- + 100));
- + // Negative channels is not allowed
- + assertThat(exception).hasMessageThat().ignoringCase().contains("channels");
- + }
- +
- + @Test
- + public void loadSucceedsFromArray() throws Exception {
- + TensorAudioFormat format =
- + TensorAudioFormat.builder().setChannels(2).setSampleRate(2).build();
- + TensorAudio tensor = TensorAudio.create(format, 2);
- + assertThat(tensor.getTensorBuffer().getFloatArray()).isEqualTo(new float[4]);
- +
- + tensor.load(new float[] {2.f, 0});
- + assertThat(tensor.getTensorBuffer().getFloatArray())
- + .usingTolerance(0.001f)
- + .containsExactly(new float[] {0, 0, 2.f, 0});
- +
- + tensor.load(new float[] {2.f, 3.f}, 0, 2);
- + assertThat(tensor.getTensorBuffer().getFloatArray())
- + .usingTolerance(0.001f)
- + .containsExactly(new float[] {2.f, 0, 2.f, 3.f});
- +
- + tensor.load(new float[] {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}, 1, 6);
- + // The sequence is longer than the ring buffer size so it's expected to keep only the
- + // last 4 numbers (index 3 to 6) of the load target sub-sequence (index 1 to 6).
- + assertThat(tensor.getTensorBuffer().getFloatArray())
- + .usingTolerance(0.001f)
- + .containsExactly(new float[] {5.f, 6.f, 7.f, 8.f});
- +
- + tensor.load(new short[] {Short.MAX_VALUE, Short.MIN_VALUE});
- + assertThat(tensor.getTensorBuffer().getFloatArray())
- + .usingTolerance(0.001f)
- + .containsExactly(new float[] {7.f, 8.f, 1.f, -1.f});
- +
- + tensor.load(new short[] {1000, 2000, 3000, 0, 1000, Short.MIN_VALUE, 4000, 5000, 6000},
- + 3, 6);
- + // The sequence is longer than the ring buffer size so it's expected to keep only the
- + // last 4 numbers.
- + assertThat(tensor.getTensorBuffer().getFloatArray())
- + .usingTolerance(0.001f)
- + .containsExactly(new float[] {-1.f, 4000.f / Short.MAX_VALUE,
- + 5000.f / Short.MAX_VALUE, 6000.f / Short.MAX_VALUE});
- + }
- +
- + @Test
- + public void loadFailsWithIndexOutOfRange() throws Exception {
- + TensorAudioFormat format = TensorAudioFormat.builder().setSampleRate(2).build();
- + TensorAudio tensor = TensorAudio.create(format, 5);
- +
- + assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[100], 99, 2));
- +
- + assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[100], 99, 2));
- + }
- +
- + @Test
- + public void loadFailsWithIncompatibleInputSize() throws Exception {
- + TensorAudioFormat format =
- + TensorAudioFormat.builder().setChannels(3).setSampleRate(2).build();
- + TensorAudio tensor = TensorAudio.create(format, 5);
- +
- + assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[1]));
- +
- + assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[2]));
- +
- + assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[2], 1, 1));
- +
- + assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[5], 2, 4));
- + }
- +
- + @Test
- + public void loadAudioRecordSucceeds() throws Exception {
- + TensorAudio tensor =
- + TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
- + tensor.load(new float[] {1, 2, 3, 4, 5});
- + assertThat(tensor.getTensorBuffer().getFloatArray())
- + .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f});
- +
- + AudioRecord record = mock(AudioRecord.class);
- + when(record.getBufferSizeInFrames()).thenReturn(5);
- + when(record.getChannelCount()).thenReturn(1);
- + when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT);
- + when(record.getFormat())
- + .thenReturn(new AudioFormat.Builder()
- + .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
- + .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
- + .setSampleRate(16000)
- + .build());
- + // Unused
- + when(record.read(
- + any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- + .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
- + // Used
- + when(record.read(
- + any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- + .thenReturn(1);
- + assertThat(tensor.load(record)).isEqualTo(1);
- + assertThat(tensor.getTensorBuffer().getFloatArray())
- + .isEqualTo(new float[] {3.f, 4.f, 5.f, 0});
- +
- + record = mock(AudioRecord.class);
- + when(record.getBufferSizeInFrames()).thenReturn(5);
- + when(record.getChannelCount()).thenReturn(1);
- + when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_16BIT);
- + when(record.getFormat())
- + .thenReturn(new AudioFormat.Builder()
- + .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
- + .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
- + .setSampleRate(16000)
- + .build());
- + // Used
- + when(record.read(
- + any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- + .thenReturn(2);
- + // Unused
- + when(record.read(
- + any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- + .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
- + assertThat(tensor.load(record)).isEqualTo(2);
- + assertThat(tensor.getTensorBuffer().getFloatArray())
- + .isEqualTo(new float[] {5.f, 0, 0, 0});
- + }
- +
- + @Test
- + public void loadAudioRecordFailsWithErrorState() throws Exception {
- + TensorAudio tensor =
- + TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
- + tensor.load(new float[] {1, 2, 3, 4, 5});
- + assertThat(tensor.getTensorBuffer().getFloatArray())
- + .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f});
- +
- + AudioRecord record = mock(AudioRecord.class);
- + when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT);
- + when(record.getFormat())
- + .thenReturn(new AudioFormat.Builder()
- + .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
- + .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
- + .setSampleRate(16000)
- + .build());
- + // Unused
- + when(record.read(
- + any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- + .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
- + // Used
- + when(record.read(
- + any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- + .thenReturn(AudioRecord.ERROR_DEAD_OBJECT);
- + IllegalStateException exception =
- + assertThrows(IllegalStateException.class, () -> tensor.load(record));
- + assertThat(exception).hasMessageThat().contains("ERROR_DEAD_OBJECT");
- + }
- +
- + @Test
- + public void loadAudioRecordFailsWithUnsupportedAudioEncoding() throws Exception {
- + TensorAudio tensor =
- + TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
- + AudioRecord record = mock(AudioRecord.class);
- + when(record.getFormat())
- + .thenReturn(new AudioFormat.Builder()
- + .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
- + .setEncoding(AudioFormat.ENCODING_PCM_8BIT) // Not supported
- + .setSampleRate(16000)
- + .build());
- + when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_8BIT);
- +
- + IllegalArgumentException exception =
- + assertThrows(IllegalArgumentException.class, () -> tensor.load(record));
- + assertThat(exception).hasMessageThat().ignoringCase().contains("unsupported encoding");
- + }
- +
- + @Test
- + public void loadAudioRecordFailsWithIncompatibleAudioFormat() throws Exception {
- + TensorAudio tensor =
- + TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
- + AudioRecord record = mock(AudioRecord.class);
- + when(record.getFormat())
- + .thenReturn(new AudioFormat.Builder()
- + .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
- + .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
- + .setSampleRate(44100) // Mismatch
- + .build());
- +
- + IllegalArgumentException exception =
- + assertThrows(IllegalArgumentException.class, () -> tensor.load(record));
- + assertThat(exception).hasMessageThat().ignoringCase().contains(
- + "Incompatible audio format");
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java
- index d97665d1ed771..1d26476733c98 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java
- @@ -18,78 +18,81 @@ package org.tensorflow.lite.support.common;
- import static com.google.common.truth.Truth.assertThat;
-
- import android.content.Context;
- +
- import androidx.test.core.app.ApplicationProvider;
- +
- +import org.junit.Assert;
- +import org.junit.Test;
- +import org.junit.runner.RunWith;
- +import org.robolectric.RobolectricTestRunner;
- +
- import java.io.ByteArrayInputStream;
- import java.io.IOException;
- import java.io.InputStream;
- import java.nio.MappedByteBuffer;
- import java.nio.charset.Charset;
- import java.util.List;
- -import org.junit.Assert;
- -import org.junit.Test;
- -import org.junit.runner.RunWith;
- -import org.robolectric.RobolectricTestRunner;
-
- /** Tests of {@link org.tensorflow.lite.support.common.FileUtil}. */
- @RunWith(RobolectricTestRunner.class)
- public final class FileUtilTest {
- - private final Context context = ApplicationProvider.getApplicationContext();
- - private static final String LABEL_PATH = "flower_labels.txt";
- -
- - @Test
- - public void testLoadLabels() throws IOException {
- - List<String> labels = FileUtil.loadLabels(context, LABEL_PATH);
- - assertThat(labels)
- - .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips")
- - .inOrder();
- - }
- -
- - @Test
- - public void testLoadLabelsFromInputStream() throws IOException {
- - InputStream inputStream = context.getAssets().open(LABEL_PATH);
- - assertThat(FileUtil.loadLabels(inputStream))
- - .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips")
- - .inOrder();
- - }
- -
- - @Test
- - public void whitespaceLabelsShouldNotCount() throws IOException {
- - String s = "a\nb\n \n\n\nc";
- - InputStream stream = new ByteArrayInputStream(s.getBytes(Charset.defaultCharset()));
- - assertThat(FileUtil.loadLabels(stream)).hasSize(3);
- - }
- -
- - @Test
- - public void testLoadLabelsNullContext() throws IOException {
- - Context nullContext = null;
- - Assert.assertThrows(
- - NullPointerException.class, () -> FileUtil.loadLabels(nullContext, LABEL_PATH));
- - }
- -
- - @Test
- - public void testLoadLabelsNullFilePath() throws IOException {
- - String nullFilePath = null;
- - Assert.assertThrows(
- - NullPointerException.class, () -> FileUtil.loadLabels(context, nullFilePath));
- - }
- -
- - @Test
- - public void testLoadMappedFile() throws IOException {
- - MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, LABEL_PATH);
- - assertThat(byteModel).isNotNull();
- - }
- -
- - @Test
- - public void testLoadMappedFileWithNullContext() throws IOException {
- - Context nullContext = null;
- - Assert.assertThrows(
- - NullPointerException.class, () -> FileUtil.loadMappedFile(nullContext, LABEL_PATH));
- - }
- -
- - @Test
- - public void loadMappedFileWithNullFilePath() throws IOException {
- - String nullFilePath = null;
- - Assert.assertThrows(
- - NullPointerException.class, () -> FileUtil.loadMappedFile(context, nullFilePath));
- - }
- + private final Context context = ApplicationProvider.getApplicationContext();
- + private static final String LABEL_PATH = "flower_labels.txt";
- +
- + @Test
- + public void testLoadLabels() throws IOException {
- + List<String> labels = FileUtil.loadLabels(context, LABEL_PATH);
- + assertThat(labels)
- + .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips")
- + .inOrder();
- + }
- +
- + @Test
- + public void testLoadLabelsFromInputStream() throws IOException {
- + InputStream inputStream = context.getAssets().open(LABEL_PATH);
- + assertThat(FileUtil.loadLabels(inputStream))
- + .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips")
- + .inOrder();
- + }
- +
- + @Test
- + public void whitespaceLabelsShouldNotCount() throws IOException {
- + String s = "a\nb\n \n\n\nc";
- + InputStream stream = new ByteArrayInputStream(s.getBytes(Charset.defaultCharset()));
- + assertThat(FileUtil.loadLabels(stream)).hasSize(3);
- + }
- +
- + @Test
- + public void testLoadLabelsNullContext() throws IOException {
- + Context nullContext = null;
- + Assert.assertThrows(
- + NullPointerException.class, () -> FileUtil.loadLabels(nullContext, LABEL_PATH));
- + }
- +
- + @Test
- + public void testLoadLabelsNullFilePath() throws IOException {
- + String nullFilePath = null;
- + Assert.assertThrows(
- + NullPointerException.class, () -> FileUtil.loadLabels(context, nullFilePath));
- + }
- +
- + @Test
- + public void testLoadMappedFile() throws IOException {
- + MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, LABEL_PATH);
- + assertThat(byteModel).isNotNull();
- + }
- +
- + @Test
- + public void testLoadMappedFileWithNullContext() throws IOException {
- + Context nullContext = null;
- + Assert.assertThrows(
- + NullPointerException.class, () -> FileUtil.loadMappedFile(nullContext, LABEL_PATH));
- + }
- +
- + @Test
- + public void loadMappedFileWithNullFilePath() throws IOException {
- + String nullFilePath = null;
- + Assert.assertThrows(
- + NullPointerException.class, () -> FileUtil.loadMappedFile(context, nullFilePath));
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java
- index 43a7f7cd1ce29..82f97f2534cf7 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java
- @@ -27,59 +27,58 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- /** Tests for {@link TensorProcessor}. */
- @RunWith(RobolectricTestRunner.class)
- public final class TensorProcessorTest {
- + private static final int EXAMPLE_NUM_FEATURES = 1000;
- + private static final float MEAN = 127.5f;
- + private static final float STDDEV = 127.5f;
-
- - private static final int EXAMPLE_NUM_FEATURES = 1000;
- - private static final float MEAN = 127.5f;
- - private static final float STDDEV = 127.5f;
- -
- - @Test
- - public void testBuild() {
- - TensorProcessor processor =
- - new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
- - assertThat(processor).isNotNull();
- - }
- + @Test
- + public void testBuild() {
- + TensorProcessor processor =
- + new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
- + assertThat(processor).isNotNull();
- + }
-
- - @Test
- - public void testNormalize() {
- - TensorBuffer input = createExampleTensorBuffer();
- - TensorProcessor processor =
- - new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
- - TensorBuffer output = processor.process(input);
- + @Test
- + public void testNormalize() {
- + TensorBuffer input = createExampleTensorBuffer();
- + TensorProcessor processor =
- + new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
- + TensorBuffer output = processor.process(input);
-
- - float[] pixels = output.getFloatArray();
- - assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES);
- - for (float p : pixels) {
- - assertThat(p).isAtLeast(-1);
- - assertThat(p).isAtMost(1);
- + float[] pixels = output.getFloatArray();
- + assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES);
- + for (float p : pixels) {
- + assertThat(p).isAtLeast(-1);
- + assertThat(p).isAtMost(1);
- + }
- }
- - }
-
- - @Test
- - public void testMultipleNormalize() {
- - TensorBuffer input = createExampleTensorBuffer();
- - TensorProcessor processor =
- - new TensorProcessor.Builder()
- - .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1]
- - .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1]
- - .build();
- - TensorBuffer output = processor.process(input);
- + @Test
- + public void testMultipleNormalize() {
- + TensorBuffer input = createExampleTensorBuffer();
- + TensorProcessor processor =
- + new TensorProcessor.Builder()
- + .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1]
- + .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1]
- + .build();
- + TensorBuffer output = processor.process(input);
-
- - float[] pixels = output.getFloatArray();
- - assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES);
- - for (float p : pixels) {
- - assertThat(p).isAtLeast(0);
- - assertThat(p).isAtMost(1);
- + float[] pixels = output.getFloatArray();
- + assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES);
- + for (float p : pixels) {
- + assertThat(p).isAtLeast(0);
- + assertThat(p).isAtMost(1);
- + }
- }
- - }
-
- - // Creates a TensorBuffer of size {1, 1000}, containing values in range [0, 255].
- - private static TensorBuffer createExampleTensorBuffer() {
- - TensorBuffer buffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- - int[] features = new int[EXAMPLE_NUM_FEATURES];
- - for (int i = 0; i < EXAMPLE_NUM_FEATURES; i++) {
- - features[i] = i % 256;
- + // Creates a TensorBuffer of size {1, 1000}, containing values in range [0, 255].
- + private static TensorBuffer createExampleTensorBuffer() {
- + TensorBuffer buffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- + int[] features = new int[EXAMPLE_NUM_FEATURES];
- + for (int i = 0; i < EXAMPLE_NUM_FEATURES; i++) {
- + features[i] = i % 256;
- + }
- + buffer.loadArray(features, new int[] {1, EXAMPLE_NUM_FEATURES});
- + return buffer;
- }
- - buffer.loadArray(features, new int[] {1, EXAMPLE_NUM_FEATURES});
- - return buffer;
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java
- index a159c71863322..e8ba24d27550b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java
- @@ -27,56 +27,55 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- /** Tests of {@link CastOp}. */
- @RunWith(RobolectricTestRunner.class)
- public final class CastOpTest {
- + private static final float[] FLOAT_ARRAY = new float[] {1.1f, 3.3f, 5.5f, 7.7f, 9.9f};
- + private static final float[] CASTED_FLOAT_ARRAY = new float[] {1.0f, 3.0f, 5.0f, 7.0f, 9.0f};
- + private static final int[] INT_ARRAY = new int[] {1, 3, 5, 7, 9};
- + private static final int[] SHAPE = new int[] {5};
-
- - private static final float[] FLOAT_ARRAY = new float[] {1.1f, 3.3f, 5.5f, 7.7f, 9.9f};
- - private static final float[] CASTED_FLOAT_ARRAY = new float[] {1.0f, 3.0f, 5.0f, 7.0f, 9.0f};
- - private static final int[] INT_ARRAY = new int[] {1, 3, 5, 7, 9};
- - private static final int[] SHAPE = new int[] {5};
- -
- - @Test
- - public void castFloat32ToUint8ShouldSuccess() {
- - TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- - floatBuffer.loadArray(FLOAT_ARRAY, SHAPE);
- - CastOp op = new CastOp(DataType.UINT8);
- - TensorBuffer uint8Buffer = op.apply(floatBuffer);
- - assertThat(uint8Buffer.getDataType()).isEqualTo(DataType.UINT8);
- - assertThat(uint8Buffer.getIntArray()).isEqualTo(INT_ARRAY);
- - }
- + @Test
- + public void castFloat32ToUint8ShouldSuccess() {
- + TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- + floatBuffer.loadArray(FLOAT_ARRAY, SHAPE);
- + CastOp op = new CastOp(DataType.UINT8);
- + TensorBuffer uint8Buffer = op.apply(floatBuffer);
- + assertThat(uint8Buffer.getDataType()).isEqualTo(DataType.UINT8);
- + assertThat(uint8Buffer.getIntArray()).isEqualTo(INT_ARRAY);
- + }
-
- - @Test
- - public void castUint8ToFloat32ShouldSuccess() {
- - TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8);
- - uint8Buffer.loadArray(INT_ARRAY, SHAPE);
- - CastOp op = new CastOp(DataType.FLOAT32);
- - TensorBuffer floatBuffer = op.apply(uint8Buffer);
- - assertThat(floatBuffer.getDataType()).isEqualTo(DataType.FLOAT32);
- - assertThat(floatBuffer.getFloatArray()).isEqualTo(CASTED_FLOAT_ARRAY);
- - }
- + @Test
- + public void castUint8ToFloat32ShouldSuccess() {
- + TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8);
- + uint8Buffer.loadArray(INT_ARRAY, SHAPE);
- + CastOp op = new CastOp(DataType.FLOAT32);
- + TensorBuffer floatBuffer = op.apply(uint8Buffer);
- + assertThat(floatBuffer.getDataType()).isEqualTo(DataType.FLOAT32);
- + assertThat(floatBuffer.getFloatArray()).isEqualTo(CASTED_FLOAT_ARRAY);
- + }
-
- - @Test
- - public void castFloat32ToFloat32ShouldNotRecreate() {
- - TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- - floatBuffer.loadArray(FLOAT_ARRAY, SHAPE);
- - CastOp op = new CastOp(DataType.FLOAT32);
- - TensorBuffer newBuffer = op.apply(floatBuffer);
- - assertThat(newBuffer.getDataType()).isEqualTo(DataType.FLOAT32);
- - assertThat(newBuffer).isSameInstanceAs(floatBuffer);
- - }
- + @Test
- + public void castFloat32ToFloat32ShouldNotRecreate() {
- + TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- + floatBuffer.loadArray(FLOAT_ARRAY, SHAPE);
- + CastOp op = new CastOp(DataType.FLOAT32);
- + TensorBuffer newBuffer = op.apply(floatBuffer);
- + assertThat(newBuffer.getDataType()).isEqualTo(DataType.FLOAT32);
- + assertThat(newBuffer).isSameInstanceAs(floatBuffer);
- + }
-
- - @Test
- - public void castUint8ToUint8ShouldNotRecreate() {
- - TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8);
- - uint8Buffer.loadArray(INT_ARRAY, SHAPE);
- - CastOp op = new CastOp(DataType.UINT8);
- - TensorBuffer newBuffer = op.apply(uint8Buffer);
- - assertThat(newBuffer.getDataType()).isEqualTo(DataType.UINT8);
- - assertThat(newBuffer).isSameInstanceAs(uint8Buffer);
- - }
- + @Test
- + public void castUint8ToUint8ShouldNotRecreate() {
- + TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8);
- + uint8Buffer.loadArray(INT_ARRAY, SHAPE);
- + CastOp op = new CastOp(DataType.UINT8);
- + TensorBuffer newBuffer = op.apply(uint8Buffer);
- + assertThat(newBuffer.getDataType()).isEqualTo(DataType.UINT8);
- + assertThat(newBuffer).isSameInstanceAs(uint8Buffer);
- + }
-
- - @Test
- - public void castToUnsupportedDataTypeShouldThrow() {
- - for (DataType type : new DataType[] {DataType.INT32, DataType.INT64, DataType.STRING}) {
- - Assert.assertThrows(IllegalArgumentException.class, () -> new CastOp(type));
- + @Test
- + public void castToUnsupportedDataTypeShouldThrow() {
- + for (DataType type : new DataType[] {DataType.INT32, DataType.INT64, DataType.STRING}) {
- + Assert.assertThrows(IllegalArgumentException.class, () -> new CastOp(type));
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java
- index 99ded56ce069a..a69bcd7ec0296 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java
- @@ -26,16 +26,15 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- /** Tests of {@link DequantizeOp}. */
- @RunWith(RobolectricTestRunner.class)
- public final class DequantizeOpTest {
- -
- - @Test
- - public void dequantizeShouldSucess() {
- - int[] originalData = new int[] {191, 159, 63, 127, 255, 0};
- - DequantizeOp op = new DequantizeOp(127.0f, 1.0f / 128);
- - TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.UINT8);
- - input.loadArray(originalData);
- - TensorBuffer dequantized = op.apply(input);
- - assertThat(dequantized.getDataType()).isEqualTo(DataType.FLOAT32);
- - assertThat(dequantized.getFloatArray())
- - .isEqualTo(new float[] {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f});
- - }
- + @Test
- + public void dequantizeShouldSucess() {
- + int[] originalData = new int[] {191, 159, 63, 127, 255, 0};
- + DequantizeOp op = new DequantizeOp(127.0f, 1.0f / 128);
- + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.UINT8);
- + input.loadArray(originalData);
- + TensorBuffer dequantized = op.apply(input);
- + assertThat(dequantized.getDataType()).isEqualTo(DataType.FLOAT32);
- + assertThat(dequantized.getFloatArray())
- + .isEqualTo(new float[] {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f});
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java
- index 09ef275a826bc..aabc6be926106 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java
- @@ -16,6 +16,7 @@ limitations under the License.
- package org.tensorflow.lite.support.common.ops;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.tensorflow.lite.DataType.FLOAT32;
- import static org.tensorflow.lite.DataType.UINT8;
-
- @@ -31,122 +32,120 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- */
- @RunWith(RobolectricTestRunner.class)
- public final class NormalizeOpTest {
- + private static final float MEAN = 50;
- + private static final float STDDEV = 50;
- + private static final int NUM_ELEMENTS = 100;
- +
- + @Test
- + public void testNormalizeIntBuffer() {
- + int[] inputArr = new int[NUM_ELEMENTS];
- + for (int i = 0; i < NUM_ELEMENTS; i++) {
- + inputArr[i] = i;
- + }
- + TensorBuffer input = TensorBuffer.createDynamic(DataType.UINT8);
- + input.loadArray(inputArr, new int[] {inputArr.length});
- + NormalizeOp op = new NormalizeOp(MEAN, STDDEV);
- + TensorBuffer output = op.apply(input);
- + assertThat(output.getDataType()).isEqualTo(FLOAT32);
- + float[] outputArr = output.getFloatArray();
- + for (int i = 0; i < NUM_ELEMENTS; i++) {
- + assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV);
- + }
- + }
-
- - private static final float MEAN = 50;
- - private static final float STDDEV = 50;
- - private static final int NUM_ELEMENTS = 100;
- + @Test
- + public void testNormalizeFloatBuffer() {
- + float[] inputArr = new float[NUM_ELEMENTS];
- + for (int i = 0; i < NUM_ELEMENTS; i++) {
- + inputArr[i] = i;
- + }
- + TensorBuffer input = TensorBuffer.createDynamic(FLOAT32);
- + input.loadArray(inputArr, new int[] {inputArr.length});
- + NormalizeOp op = new NormalizeOp(MEAN, STDDEV);
- + TensorBuffer output = op.apply(input);
- + assertThat(output.getDataType()).isEqualTo(FLOAT32);
- + float[] outputArr = output.getFloatArray();
- + for (int i = 0; i < NUM_ELEMENTS; i++) {
- + assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV);
- + }
- + }
-
- - @Test
- - public void testNormalizeIntBuffer() {
- - int[] inputArr = new int[NUM_ELEMENTS];
- - for (int i = 0; i < NUM_ELEMENTS; i++) {
- - inputArr[i] = i;
- + @Test
- + public void testZeroStddev() {
- + Assert.assertThrows(IllegalArgumentException.class, () -> new NormalizeOp(1, 0));
- }
- - TensorBuffer input = TensorBuffer.createDynamic(DataType.UINT8);
- - input.loadArray(inputArr, new int[] {inputArr.length});
- - NormalizeOp op = new NormalizeOp(MEAN, STDDEV);
- - TensorBuffer output = op.apply(input);
- - assertThat(output.getDataType()).isEqualTo(FLOAT32);
- - float[] outputArr = output.getFloatArray();
- - for (int i = 0; i < NUM_ELEMENTS; i++) {
- - assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV);
- +
- + @Test
- + public void testIdentityShortcut() {
- + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
- + NormalizeOp op = new NormalizeOp(0, 1);
- + TensorBuffer output = op.apply(input);
- + assertThat(output.getDataType()).isEqualTo(UINT8);
- + assertThat(output).isSameInstanceAs(input);
- }
- - }
-
- - @Test
- - public void testNormalizeFloatBuffer() {
- - float[] inputArr = new float[NUM_ELEMENTS];
- - for (int i = 0; i < NUM_ELEMENTS; i++) {
- - inputArr[i] = i;
- + @Test
- + public void testNormalizeOp_zeroMeanAndZeroStddev() {
- + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
- + NormalizeOp op = new NormalizeOp(0, 0);
- + TensorBuffer output = op.apply(input);
- + assertThat(output.getDataType()).isEqualTo(UINT8);
- + assertThat(output).isSameInstanceAs(input);
- }
- - TensorBuffer input = TensorBuffer.createDynamic(FLOAT32);
- - input.loadArray(inputArr, new int[] {inputArr.length});
- - NormalizeOp op = new NormalizeOp(MEAN, STDDEV);
- - TensorBuffer output = op.apply(input);
- - assertThat(output.getDataType()).isEqualTo(FLOAT32);
- - float[] outputArr = output.getFloatArray();
- - for (int i = 0; i < NUM_ELEMENTS; i++) {
- - assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV);
- +
- + @Test
- + public void testNormalizeOp_zeroMeanAndInifityStddev() {
- + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
- + NormalizeOp op = new NormalizeOp(0, Float.POSITIVE_INFINITY);
- + TensorBuffer output = op.apply(input);
- + assertThat(output.getDataType()).isEqualTo(UINT8);
- + assertThat(output).isSameInstanceAs(input);
- }
- - }
- -
- - @Test
- - public void testZeroStddev() {
- - Assert.assertThrows(IllegalArgumentException.class, () -> new NormalizeOp(1, 0));
- - }
- -
- - @Test
- - public void testIdentityShortcut() {
- - TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
- - NormalizeOp op = new NormalizeOp(0, 1);
- - TensorBuffer output = op.apply(input);
- - assertThat(output.getDataType()).isEqualTo(UINT8);
- - assertThat(output).isSameInstanceAs(input);
- - }
- -
- - @Test
- - public void testNormalizeOp_zeroMeanAndZeroStddev() {
- - TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
- - NormalizeOp op = new NormalizeOp(0, 0);
- - TensorBuffer output = op.apply(input);
- - assertThat(output.getDataType()).isEqualTo(UINT8);
- - assertThat(output).isSameInstanceAs(input);
- - }
- -
- - @Test
- - public void testNormalizeOp_zeroMeanAndInifityStddev() {
- - TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
- - NormalizeOp op = new NormalizeOp(0, Float.POSITIVE_INFINITY);
- - TensorBuffer output = op.apply(input);
- - assertThat(output.getDataType()).isEqualTo(UINT8);
- - assertThat(output).isSameInstanceAs(input);
- - }
- -
- - @Test
- - public void testMultiChannelNormalize() {
- - float[] inputArr = new float[NUM_ELEMENTS];
- - for (int i = 0; i < NUM_ELEMENTS; i++) {
- - inputArr[i] = i;
- +
- + @Test
- + public void testMultiChannelNormalize() {
- + float[] inputArr = new float[NUM_ELEMENTS];
- + for (int i = 0; i < NUM_ELEMENTS; i++) {
- + inputArr[i] = i;
- + }
- + TensorBuffer input = TensorBuffer.createDynamic(FLOAT32);
- + input.loadArray(inputArr, new int[] {20, 5});
- + float[] means = new float[] {1, 2, 3, 4, 5};
- + float[] stddevs = new float[] {6, 7, 8, 9, 10};
- + NormalizeOp op = new NormalizeOp(means, stddevs);
- + TensorBuffer output = op.apply(input);
- + assertThat(output.getDataType()).isEqualTo(FLOAT32);
- + float[] outputArr = output.getFloatArray();
- + for (int i = 0; i < NUM_ELEMENTS; i++) {
- + assertThat(outputArr[i]).isEqualTo((i - means[i % 5]) / stddevs[i % 5]);
- + }
- }
- - TensorBuffer input = TensorBuffer.createDynamic(FLOAT32);
- - input.loadArray(inputArr, new int[] {20, 5});
- - float[] means = new float[] {1, 2, 3, 4, 5};
- - float[] stddevs = new float[] {6, 7, 8, 9, 10};
- - NormalizeOp op = new NormalizeOp(means, stddevs);
- - TensorBuffer output = op.apply(input);
- - assertThat(output.getDataType()).isEqualTo(FLOAT32);
- - float[] outputArr = output.getFloatArray();
- - for (int i = 0; i < NUM_ELEMENTS; i++) {
- - assertThat(outputArr[i]).isEqualTo((i - means[i % 5]) / stddevs[i % 5]);
- +
- + @Test
- + public void testMultiChannelShortcut() {
- + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
- + NormalizeOp op = new NormalizeOp(new float[] {0, 0, 0}, new float[] {1, 1, 1});
- + TensorBuffer output = op.apply(input);
- + assertThat(output.getDataType()).isEqualTo(UINT8);
- + assertThat(output).isSameInstanceAs(input);
- + }
- +
- + @Test
- + public void testMismatchedNumbersOfMeansAndStddevs() {
- + Assert.assertThrows(IllegalArgumentException.class,
- + () -> new NormalizeOp(new float[] {2, 3}, new float[] {1}));
- + }
- +
- + @Test
- + public void testMismatchedInputTensorChannelNum() {
- + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
- + NormalizeOp op = new NormalizeOp(new float[] {0, 0}, new float[] {1, 2});
- + Assert.assertThrows(IllegalArgumentException.class, () -> op.apply(input));
- + }
- +
- + @Test
- + public void testAnyChannelInvalidStddev() {
- + Assert.assertThrows(IllegalArgumentException.class,
- + () -> new NormalizeOp(new float[] {2, 3}, new float[] {1, 0}));
- }
- - }
- -
- - @Test
- - public void testMultiChannelShortcut() {
- - TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
- - NormalizeOp op = new NormalizeOp(new float[] {0, 0, 0}, new float[] {1, 1, 1});
- - TensorBuffer output = op.apply(input);
- - assertThat(output.getDataType()).isEqualTo(UINT8);
- - assertThat(output).isSameInstanceAs(input);
- - }
- -
- - @Test
- - public void testMismatchedNumbersOfMeansAndStddevs() {
- - Assert.assertThrows(
- - IllegalArgumentException.class, () -> new NormalizeOp(new float[] {2, 3}, new float[] {1}));
- - }
- -
- - @Test
- - public void testMismatchedInputTensorChannelNum() {
- - TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
- - NormalizeOp op = new NormalizeOp(new float[] {0, 0}, new float[] {1, 2});
- - Assert.assertThrows(IllegalArgumentException.class, () -> op.apply(input));
- - }
- -
- - @Test
- - public void testAnyChannelInvalidStddev() {
- - Assert.assertThrows(
- - IllegalArgumentException.class,
- - () -> new NormalizeOp(new float[] {2, 3}, new float[] {1, 0}));
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java
- index 8ef72f92e0696..519cd287e1575 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java
- @@ -26,15 +26,14 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- /** Tests of {@link QuantizeOp}. */
- @RunWith(RobolectricTestRunner.class)
- public final class QuantizeOpTest {
- -
- - @Test
- - public void quantizeShouldSuccess() {
- - float[] originalData = {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f}; // -0.9921875 == -127 / 128
- - QuantizeOp op = new QuantizeOp(127.0f, 1.0f / 128);
- - TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.FLOAT32);
- - input.loadArray(originalData);
- - TensorBuffer quantized = op.apply(input);
- - assertThat(quantized.getDataType()).isEqualTo(DataType.FLOAT32);
- - assertThat(quantized.getIntArray()).isEqualTo(new int[] {191, 159, 63, 127, 255, 0});
- - }
- + @Test
- + public void quantizeShouldSuccess() {
- + float[] originalData = {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f}; // -0.9921875 == -127 / 128
- + QuantizeOp op = new QuantizeOp(127.0f, 1.0f / 128);
- + TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.FLOAT32);
- + input.loadArray(originalData);
- + TensorBuffer quantized = op.apply(input);
- + assertThat(quantized.getDataType()).isEqualTo(DataType.FLOAT32);
- + assertThat(quantized.getIntArray()).isEqualTo(new int[] {191, 159, 63, 127, 255, 0});
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java
- index 7f16c8e95628d..e8edb588c61c6 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java
- @@ -18,7 +18,7 @@ package org.tensorflow.lite.support.image;
- import static com.google.common.truth.Truth.assertThat;
-
- import android.graphics.RectF;
- -import java.util.List;
- +
- import org.junit.Assert;
- import org.junit.Before;
- import org.junit.Test;
- @@ -28,213 +28,142 @@ import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.image.BoundingBoxUtil.CoordinateType;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- +import java.util.List;
- +
- /** Tests of {@link BoundingBoxUtil}. */
- @RunWith(RobolectricTestRunner.class)
- public class BoundingBoxUtilTest {
- -
- - private TensorBuffer tensorBuffer;
- -
- - @Before
- - public void setUp() {
- - // 2 bounding boxes with additional batch dimension.
- - tensorBuffer = TensorBuffer.createFixedSize(new int[] {1, 2, 4}, DataType.FLOAT32);
- - }
- -
- - @Test
- - public void convertDefaultRatioBoundaries() {
- - tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f});
- -
- - List<RectF> boxList =
- - BoundingBoxUtil.convert(
- - tensorBuffer,
- - new int[] {0, 1, 2, 3},
- - -1,
- - BoundingBoxUtil.Type.BOUNDARIES,
- - CoordinateType.RATIO,
- - 500,
- - 400);
- -
- - assertThat(boxList).hasSize(2);
- - assertThat(boxList.get(0)).isEqualTo(new RectF(100, 100, 300, 400));
- - assertThat(boxList.get(1)).isEqualTo(new RectF(200, 0, 400, 500));
- - }
- -
- - @Test
- - public void convertComplexTensor() {
- - tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 4, 2}, DataType.FLOAT32);
- - tensorBuffer.loadArray(
- - new float[] {
- - // sub tensor 0
- - 0, 1, 10, 11, 20, 21, 30, 31,
- - // sub tensor 1
- - 100, 101, 110, 111, 120, 121, 130, 131,
- - // sub tensor 2
- - 200, 201, 210, 211, 220, 221, 230, 231
- - });
- -
- - List<RectF> boxList =
- - BoundingBoxUtil.convert(
- - tensorBuffer,
- - new int[] {0, 1, 2, 3},
- - 1,
- - BoundingBoxUtil.Type.BOUNDARIES,
- - CoordinateType.PIXEL,
- - 0,
- - 0);
- -
- - assertThat(boxList).hasSize(6);
- - assertThat(boxList.get(0)).isEqualTo(new RectF(0, 10, 20, 30));
- - assertThat(boxList.get(1)).isEqualTo(new RectF(1, 11, 21, 31));
- - assertThat(boxList.get(2)).isEqualTo(new RectF(100, 110, 120, 130));
- - assertThat(boxList.get(3)).isEqualTo(new RectF(101, 111, 121, 131));
- - }
- -
- - @Test
- - public void convertIndexedRatioBoundaries() {
- - tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f});
- -
- - List<RectF> boxList =
- - BoundingBoxUtil.convert(
- - tensorBuffer,
- - new int[] {1, 0, 3, 2},
- - -1,
- - BoundingBoxUtil.Type.BOUNDARIES,
- - CoordinateType.RATIO,
- - 500,
- - 400);
- -
- - assertThat(boxList).hasSize(2);
- - assertThat(boxList.get(0)).isEqualTo(new RectF(80, 125, 320, 375));
- - assertThat(boxList.get(1)).isEqualTo(new RectF(0, 250, 400, 500));
- - }
- -
- - @Test
- - public void convertPixelBoundaries() {
- - tensorBuffer.loadArray(new float[] {100, 100, 300, 400, 200, 0, 400, 500});
- -
- - List<RectF> boxList =
- - BoundingBoxUtil.convert(
- - tensorBuffer,
- - new int[] {0, 1, 2, 3},
- - -1,
- - BoundingBoxUtil.Type.BOUNDARIES,
- - CoordinateType.PIXEL,
- - 500,
- - 400);
- -
- - assertThat(boxList)
- - .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
- - .inOrder();
- - }
- -
- - @Test
- - public void convertRatioUpperLeft() {
- - tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.5f, 0.6f, 0.5f, 0.0f, 0.5f, 1.0f});
- -
- - List<RectF> boxList =
- - BoundingBoxUtil.convert(
- - tensorBuffer,
- - new int[] {0, 1, 2, 3},
- - -1,
- - BoundingBoxUtil.Type.UPPER_LEFT,
- - CoordinateType.RATIO,
- - 500,
- - 400);
- -
- - assertThat(boxList).hasSize(2);
- - assertThat(boxList)
- - .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
- - .inOrder();
- - }
- -
- - @Test
- - public void convertPixelUpperLeft() {
- - tensorBuffer.loadArray(new float[] {100, 100, 200, 300, 200, 0, 200, 500});
- -
- - List<RectF> boxList =
- - BoundingBoxUtil.convert(
- - tensorBuffer,
- - new int[] {0, 1, 2, 3},
- - -1,
- - BoundingBoxUtil.Type.UPPER_LEFT,
- - CoordinateType.PIXEL,
- - 500,
- - 400);
- -
- - assertThat(boxList)
- - .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
- - .inOrder();
- - }
- -
- - @Test
- - public void convertRatioCenter() {
- - tensorBuffer.loadArray(new float[] {0.5f, 0.5f, 0.5f, 0.6f, 0.75f, 0.5f, 0.5f, 1.0f});
- -
- - List<RectF> boxList =
- - BoundingBoxUtil.convert(
- - tensorBuffer,
- - new int[] {0, 1, 2, 3},
- - -1,
- - BoundingBoxUtil.Type.CENTER,
- - CoordinateType.RATIO,
- - 500,
- - 400);
- -
- - assertThat(boxList)
- - .containsExactly(new RectF(100, 99.99999f, 300, 400), new RectF(200, 0, 400, 500))
- - .inOrder();
- - }
- -
- - @Test
- - public void convertPixelCenter() {
- - tensorBuffer.loadArray(new float[] {200, 250, 200, 300, 300, 250, 200, 500});
- -
- - List<RectF> boxList =
- - BoundingBoxUtil.convert(
- - tensorBuffer,
- - new int[] {0, 1, 2, 3},
- - -1,
- - BoundingBoxUtil.Type.CENTER,
- - CoordinateType.PIXEL,
- - 500,
- - 400);
- -
- - assertThat(boxList)
- - .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
- - .inOrder();
- - }
- -
- - @Test
- - public void convertTensorWithUnexpectedShapeShouldThrow() {
- - TensorBuffer badShapeTensor = TensorBuffer.createFixedSize(new int[] {1, 5}, DataType.FLOAT32);
- -
- - Assert.assertThrows(
- - IllegalArgumentException.class,
- - () ->
- - BoundingBoxUtil.convert(
- - badShapeTensor,
- - new int[] {0, 1, 2, 3},
- - -1,
- - BoundingBoxUtil.Type.BOUNDARIES,
- - CoordinateType.RATIO,
- - 300,
- - 400));
- - }
- -
- - @Test
- - public void convertIntTensorShouldThrow() {
- - TensorBuffer badTypeTensor = TensorBuffer.createFixedSize(new int[] {1, 4}, DataType.UINT8);
- -
- - Assert.assertThrows(
- - IllegalArgumentException.class,
- - () ->
- - BoundingBoxUtil.convert(
- - badTypeTensor,
- - new int[] {0, 1, 2, 3},
- - -1,
- - BoundingBoxUtil.Type.BOUNDARIES,
- - CoordinateType.RATIO,
- - 300,
- - 400));
- - }
- + private TensorBuffer tensorBuffer;
- +
- + @Before
- + public void setUp() {
- + // 2 bounding boxes with additional batch dimension.
- + tensorBuffer = TensorBuffer.createFixedSize(new int[] {1, 2, 4}, DataType.FLOAT32);
- + }
- +
- + @Test
- + public void convertDefaultRatioBoundaries() {
- + tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f});
- +
- + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
- + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 500, 400);
- +
- + assertThat(boxList).hasSize(2);
- + assertThat(boxList.get(0)).isEqualTo(new RectF(100, 100, 300, 400));
- + assertThat(boxList.get(1)).isEqualTo(new RectF(200, 0, 400, 500));
- + }
- +
- + @Test
- + public void convertComplexTensor() {
- + tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 4, 2}, DataType.FLOAT32);
- + tensorBuffer.loadArray(new float[] {// sub tensor 0
- + 0, 1, 10, 11, 20, 21, 30, 31,
- + // sub tensor 1
- + 100, 101, 110, 111, 120, 121, 130, 131,
- + // sub tensor 2
- + 200, 201, 210, 211, 220, 221, 230, 231});
- +
- + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, 1,
- + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.PIXEL, 0, 0);
- +
- + assertThat(boxList).hasSize(6);
- + assertThat(boxList.get(0)).isEqualTo(new RectF(0, 10, 20, 30));
- + assertThat(boxList.get(1)).isEqualTo(new RectF(1, 11, 21, 31));
- + assertThat(boxList.get(2)).isEqualTo(new RectF(100, 110, 120, 130));
- + assertThat(boxList.get(3)).isEqualTo(new RectF(101, 111, 121, 131));
- + }
- +
- + @Test
- + public void convertIndexedRatioBoundaries() {
- + tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f});
- +
- + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {1, 0, 3, 2}, -1,
- + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 500, 400);
- +
- + assertThat(boxList).hasSize(2);
- + assertThat(boxList.get(0)).isEqualTo(new RectF(80, 125, 320, 375));
- + assertThat(boxList.get(1)).isEqualTo(new RectF(0, 250, 400, 500));
- + }
- +
- + @Test
- + public void convertPixelBoundaries() {
- + tensorBuffer.loadArray(new float[] {100, 100, 300, 400, 200, 0, 400, 500});
- +
- + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
- + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.PIXEL, 500, 400);
- +
- + assertThat(boxList)
- + .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
- + .inOrder();
- + }
- +
- + @Test
- + public void convertRatioUpperLeft() {
- + tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.5f, 0.6f, 0.5f, 0.0f, 0.5f, 1.0f});
- +
- + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
- + BoundingBoxUtil.Type.UPPER_LEFT, CoordinateType.RATIO, 500, 400);
- +
- + assertThat(boxList).hasSize(2);
- + assertThat(boxList)
- + .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
- + .inOrder();
- + }
- +
- + @Test
- + public void convertPixelUpperLeft() {
- + tensorBuffer.loadArray(new float[] {100, 100, 200, 300, 200, 0, 200, 500});
- +
- + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
- + BoundingBoxUtil.Type.UPPER_LEFT, CoordinateType.PIXEL, 500, 400);
- +
- + assertThat(boxList)
- + .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
- + .inOrder();
- + }
- +
- + @Test
- + public void convertRatioCenter() {
- + tensorBuffer.loadArray(new float[] {0.5f, 0.5f, 0.5f, 0.6f, 0.75f, 0.5f, 0.5f, 1.0f});
- +
- + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
- + BoundingBoxUtil.Type.CENTER, CoordinateType.RATIO, 500, 400);
- +
- + assertThat(boxList)
- + .containsExactly(new RectF(100, 99.99999f, 300, 400), new RectF(200, 0, 400, 500))
- + .inOrder();
- + }
- +
- + @Test
- + public void convertPixelCenter() {
- + tensorBuffer.loadArray(new float[] {200, 250, 200, 300, 300, 250, 200, 500});
- +
- + List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
- + BoundingBoxUtil.Type.CENTER, CoordinateType.PIXEL, 500, 400);
- +
- + assertThat(boxList)
- + .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
- + .inOrder();
- + }
- +
- + @Test
- + public void convertTensorWithUnexpectedShapeShouldThrow() {
- + TensorBuffer badShapeTensor =
- + TensorBuffer.createFixedSize(new int[] {1, 5}, DataType.FLOAT32);
- +
- + Assert.assertThrows(IllegalArgumentException.class,
- + ()
- + -> BoundingBoxUtil.convert(badShapeTensor, new int[] {0, 1, 2, 3}, -1,
- + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 300, 400));
- + }
- +
- + @Test
- + public void convertIntTensorShouldThrow() {
- + TensorBuffer badTypeTensor = TensorBuffer.createFixedSize(new int[] {1, 4}, DataType.UINT8);
- +
- + Assert.assertThrows(IllegalArgumentException.class,
- + ()
- + -> BoundingBoxUtil.convert(badTypeTensor, new int[] {0, 1, 2, 3}, -1,
- + BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 300, 400));
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java
- index c41508308291a..329b5aa370744 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java
- @@ -15,10 +15,12 @@ limitations under the License.
- package org.tensorflow.lite.support.image;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleBitmap;
- import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleTensorBuffer;
-
- import android.graphics.Bitmap;
- +
- import org.junit.Test;
- import org.junit.runner.RunWith;
- import org.junit.runners.JUnit4;
- @@ -27,22 +29,21 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- @RunWith(JUnit4.class)
- public final class ColorSpaceTypeInstrumentedTest {
- -
- - @Test
- - public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithUint8() {
- - TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.UINT8, false);
- - Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer);
- -
- - Bitmap expectedBitmap = createGrayscaleBitmap();
- - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- - }
- -
- - @Test
- - public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithFloat() {
- - TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.FLOAT32, false);
- - Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer);
- -
- - Bitmap expectedBitmap = createGrayscaleBitmap();
- - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- - }
- + @Test
- + public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithUint8() {
- + TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.UINT8, false);
- + Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer);
- +
- + Bitmap expectedBitmap = createGrayscaleBitmap();
- + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- + }
- +
- + @Test
- + public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithFloat() {
- + TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.FLOAT32, false);
- + Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer);
- +
- + Bitmap expectedBitmap = createGrayscaleBitmap();
- + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java
- index 46977fdb2bdfa..92612255269f6 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java
- @@ -16,6 +16,7 @@ limitations under the License.
- package org.tensorflow.lite.support.image;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertThrows;
- import static org.tensorflow.lite.support.image.TestImageCreator.createRgbBitmap;
- import static org.tensorflow.lite.support.image.TestImageCreator.createRgbTensorBuffer;
- @@ -23,8 +24,7 @@ import static org.tensorflow.lite.support.image.TestImageCreator.createRgbTensor
- import android.graphics.Bitmap;
- import android.graphics.Bitmap.Config;
- import android.graphics.ImageFormat;
- -import java.util.Arrays;
- -import java.util.Collection;
- +
- import org.junit.Rule;
- import org.junit.Test;
- import org.junit.rules.ErrorCollector;
- @@ -38,386 +38,353 @@ import org.robolectric.RobolectricTestRunner;
- import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- +import java.util.Arrays;
- +import java.util.Collection;
- +
- /** Tests of {@link ImageConversions}. */
- @RunWith(Suite.class)
- -@SuiteClasses({
- - ColorSpaceTypeTest.ValidShapeTest.class,
- - ColorSpaceTypeTest.InvalidShapeTest.class,
- - ColorSpaceTypeTest.BitmapConfigTest.class,
- - ColorSpaceTypeTest.ImageFormatTest.class,
- - ColorSpaceTypeTest.YuvImageTest.class,
- - ColorSpaceTypeTest.AssertNumElementsTest.class,
- - ColorSpaceTypeTest.General.class
- -})
- +@SuiteClasses({ColorSpaceTypeTest.ValidShapeTest.class, ColorSpaceTypeTest.InvalidShapeTest.class,
- + ColorSpaceTypeTest.BitmapConfigTest.class, ColorSpaceTypeTest.ImageFormatTest.class,
- + ColorSpaceTypeTest.YuvImageTest.class, ColorSpaceTypeTest.AssertNumElementsTest.class,
- + ColorSpaceTypeTest.General.class})
- public class ColorSpaceTypeTest {
- -
- - /** Parameterized tests for valid shapes. */
- - @RunWith(ParameterizedRobolectricTestRunner.class)
- - public static final class ValidShapeTest extends ColorSpaceTypeTest {
- -
- - @Parameter(0)
- - public ColorSpaceType colorSpaceType;
- -
- - /** The shape that matches the colorSpaceType. */
- - @Parameter(1)
- - public int[] validShape;
- -
- - /** The height of validShape. */
- - @Parameter(2)
- - public int expectedHeight;
- -
- - /** The width of validShape. */
- - @Parameter(3)
- - public int expectedWidth;
- -
- - @Parameters(name = "colorSpaceType={0}; validShape={1}; height={2}; width={3}")
- - public static Collection<Object[]> data() {
- - return Arrays.asList(
- - new Object[][] {
- - {ColorSpaceType.RGB, new int[] {1, 10, 20, 3}, 10, 20},
- - {ColorSpaceType.RGB, new int[] {10, 20, 3}, 10, 20},
- - {ColorSpaceType.GRAYSCALE, new int[] {10, 20}, 10, 20},
- - {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 1}, 10, 20},
- - });
- - }
- -
- - @Test
- - public void getHeightSucceedsWithValidShape() {
- - assertThat(colorSpaceType.getHeight(validShape)).isEqualTo(expectedHeight);
- + /** Parameterized tests for valid shapes. */
- + @RunWith(ParameterizedRobolectricTestRunner.class)
- + public static final class ValidShapeTest extends ColorSpaceTypeTest {
- + @Parameter(0)
- + public ColorSpaceType colorSpaceType;
- +
- + /** The shape that matches the colorSpaceType. */
- + @Parameter(1)
- + public int[] validShape;
- +
- + /** The height of validShape. */
- + @Parameter(2)
- + public int expectedHeight;
- +
- + /** The width of validShape. */
- + @Parameter(3)
- + public int expectedWidth;
- +
- + @Parameters(name = "colorSpaceType={0}; validShape={1}; height={2}; width={3}")
- + public static Collection<Object[]> data() {
- + return Arrays.asList(new Object[][] {
- + {ColorSpaceType.RGB, new int[] {1, 10, 20, 3}, 10, 20},
- + {ColorSpaceType.RGB, new int[] {10, 20, 3}, 10, 20},
- + {ColorSpaceType.GRAYSCALE, new int[] {10, 20}, 10, 20},
- + {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 1}, 10, 20},
- + });
- + }
- +
- + @Test
- + public void getHeightSucceedsWithValidShape() {
- + assertThat(colorSpaceType.getHeight(validShape)).isEqualTo(expectedHeight);
- + }
- +
- + @Test
- + public void getWidthSucceedsWithValidShape() {
- + assertThat(colorSpaceType.getWidth(validShape)).isEqualTo(expectedWidth);
- + }
- }
-
- - @Test
- - public void getWidthSucceedsWithValidShape() {
- - assertThat(colorSpaceType.getWidth(validShape)).isEqualTo(expectedWidth);
- - }
- - }
- -
- - /** Parameterized tests for invalid shapes. */
- - @RunWith(ParameterizedRobolectricTestRunner.class)
- - public static final class InvalidShapeTest extends ColorSpaceTypeTest {
- -
- - private static final String RGB_ASSERT_SHAPE_MESSAGE =
- - "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
- - + " representing R, G, B in order. The provided image shape is ";
- - private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE =
- - "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
- - + " shape is ";
- -
- - @Parameter(0)
- - public ColorSpaceType colorSpaceType;
- -
- - /** The shape that does not match the colorSpaceType. */
- - @Parameter(1)
- - public int[] invalidShape;
- -
- - @Parameter(2)
- - public String errorMessage;
- -
- - @Parameters(name = "colorSpaceType={0}; invalidShape={1}")
- - public static Collection<Object[]> data() {
- - return Arrays.asList(
- - new Object[][] {
- - {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.RGB, new int[] {1, -10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.RGB, new int[] {1, 10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.RGB, new int[] {-10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.RGB, new int[] {10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.GRAYSCALE, new int[] {1, -10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.GRAYSCALE, new int[] {1, 10, -20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.GRAYSCALE, new int[] {-10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.GRAYSCALE, new int[] {10, -20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- - });
- + /** Parameterized tests for invalid shapes. */
- + @RunWith(ParameterizedRobolectricTestRunner.class)
- + public static final class InvalidShapeTest extends ColorSpaceTypeTest {
- + private static final String RGB_ASSERT_SHAPE_MESSAGE =
- + "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
- + + " representing R, G, B in order. The provided image shape is ";
- + private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE =
- + "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
- + + " shape is ";
- +
- + @Parameter(0)
- + public ColorSpaceType colorSpaceType;
- +
- + /** The shape that does not match the colorSpaceType. */
- + @Parameter(1)
- + public int[] invalidShape;
- +
- + @Parameter(2)
- + public String errorMessage;
- +
- + @Parameters(name = "colorSpaceType={0}; invalidShape={1}")
- + public static Collection<Object[]> data() {
- + return Arrays.asList(new Object[][] {
- + {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.RGB, new int[] {1, -10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.RGB, new int[] {1, 10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.RGB, new int[] {-10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.RGB, new int[] {10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20},
- + GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3},
- + GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.GRAYSCALE, new int[] {1, -10, 20},
- + GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.GRAYSCALE, new int[] {1, 10, -20},
- + GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4},
- + GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.GRAYSCALE, new int[] {-10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.GRAYSCALE, new int[] {10, -20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- + });
- + }
- +
- + @Test
- + public void assertShapeFaislsWithInvalidShape() {
- + IllegalArgumentException exception = assertThrows(
- + IllegalArgumentException.class, () -> colorSpaceType.assertShape(invalidShape));
- + assertThat(exception).hasMessageThat().contains(
- + errorMessage + Arrays.toString(invalidShape));
- + }
- +
- + @Test
- + public void getHeightFaislsWithInvalidShape() {
- + IllegalArgumentException exception = assertThrows(
- + IllegalArgumentException.class, () -> colorSpaceType.getHeight(invalidShape));
- + assertThat(exception).hasMessageThat().contains(
- + errorMessage + Arrays.toString(invalidShape));
- + }
- +
- + @Test
- + public void getWidthFaislsWithInvalidShape() {
- + IllegalArgumentException exception = assertThrows(
- + IllegalArgumentException.class, () -> colorSpaceType.getWidth(invalidShape));
- + assertThat(exception).hasMessageThat().contains(
- + errorMessage + Arrays.toString(invalidShape));
- + }
- }
-
- - @Test
- - public void assertShapeFaislsWithInvalidShape() {
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> colorSpaceType.assertShape(invalidShape));
- - assertThat(exception).hasMessageThat().contains(errorMessage + Arrays.toString(invalidShape));
- + /** Parameterized tests for Bitmap Config. */
- + @RunWith(ParameterizedRobolectricTestRunner.class)
- + public static final class BitmapConfigTest extends ColorSpaceTypeTest {
- + @Parameter(0)
- + public ColorSpaceType colorSpaceType;
- +
- + /** The Bitmap configuration match the colorSpaceType. */
- + @Parameter(1)
- + public Config config;
- +
- + @Parameters(name = "colorSpaceType={0}; config={1}")
- + public static Collection<Object[]> data() {
- + return Arrays.asList(new Object[][] {
- + {ColorSpaceType.RGB, Config.ARGB_8888},
- + {ColorSpaceType.GRAYSCALE, Config.ALPHA_8},
- + });
- + }
- +
- + @Test
- + public void fromBitmapConfigSucceedsWithSupportedConfig() {
- + assertThat(ColorSpaceType.fromBitmapConfig(config)).isEqualTo(colorSpaceType);
- + }
- +
- + @Test
- + public void toBitmapConfigSucceedsWithSupportedConfig() {
- + assertThat(colorSpaceType.toBitmapConfig()).isEqualTo(config);
- + }
- }
-
- - @Test
- - public void getHeightFaislsWithInvalidShape() {
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> colorSpaceType.getHeight(invalidShape));
- - assertThat(exception).hasMessageThat().contains(errorMessage + Arrays.toString(invalidShape));
- + /** Parameterized tests for ImageFormat. */
- + @RunWith(ParameterizedRobolectricTestRunner.class)
- + public static final class ImageFormatTest extends ColorSpaceTypeTest {
- + @Parameter(0)
- + public ColorSpaceType colorSpaceType;
- +
- + /** The ImageFormat that matches the colorSpaceType. */
- + @Parameter(1)
- + public int imageFormat;
- +
- + @Parameters(name = "colorSpaceType={0}; imageFormat={1}")
- + public static Collection<Object[]> data() {
- + return Arrays.asList(new Object[][] {
- + {ColorSpaceType.NV21, ImageFormat.NV21},
- + {ColorSpaceType.YV12, ImageFormat.YV12},
- + {ColorSpaceType.YUV_420_888, ImageFormat.YUV_420_888},
- + });
- + }
- +
- + @Test
- + public void fromImageFormatSucceedsWithSupportedImageFormat() {
- + assertThat(ColorSpaceType.fromImageFormat(imageFormat)).isEqualTo(colorSpaceType);
- + }
- }
-
- - @Test
- - public void getWidthFaislsWithInvalidShape() {
- - IllegalArgumentException exception =
- - assertThrows(IllegalArgumentException.class, () -> colorSpaceType.getWidth(invalidShape));
- - assertThat(exception).hasMessageThat().contains(errorMessage + Arrays.toString(invalidShape));
- - }
- - }
- -
- - /** Parameterized tests for Bitmap Config. */
- - @RunWith(ParameterizedRobolectricTestRunner.class)
- - public static final class BitmapConfigTest extends ColorSpaceTypeTest {
- -
- - @Parameter(0)
- - public ColorSpaceType colorSpaceType;
- -
- - /** The Bitmap configuration match the colorSpaceType. */
- - @Parameter(1)
- - public Config config;
- -
- - @Parameters(name = "colorSpaceType={0}; config={1}")
- - public static Collection<Object[]> data() {
- - return Arrays.asList(
- - new Object[][] {
- - {ColorSpaceType.RGB, Config.ARGB_8888},
- - {ColorSpaceType.GRAYSCALE, Config.ALPHA_8},
- - });
- + /** Parameterized tests for YUV image formats: NV12, NV21, YV12, YV21, YUV_420_888. */
- + @RunWith(ParameterizedRobolectricTestRunner.class)
- + public static final class YuvImageTest extends ColorSpaceTypeTest {
- + @Parameter(0)
- + public ColorSpaceType colorSpaceType;
- +
- + @Parameters(name = "colorSpaceType={0}")
- + public static Collection<Object[]> data() {
- + return Arrays.asList(new Object[][] {
- + {ColorSpaceType.NV12},
- + {ColorSpaceType.NV21},
- + {ColorSpaceType.YV12},
- + {ColorSpaceType.YV21},
- + {ColorSpaceType.YUV_420_888},
- + });
- + }
- +
- + @Test
- + public void convertTensorBufferToBitmapShouldFail() {
- + UnsupportedOperationException exception =
- + assertThrows(UnsupportedOperationException.class,
- + ()
- + -> colorSpaceType.convertTensorBufferToBitmap(
- + TensorBuffer.createDynamic(DataType.FLOAT32)));
- + assertThat(exception).hasMessageThat().contains(
- + "convertTensorBufferToBitmap() is unsupported for the color space type "
- + + colorSpaceType.name());
- + }
- +
- + @Test
- + public void getWidthShouldFail() {
- + UnsupportedOperationException exception =
- + assertThrows(UnsupportedOperationException.class,
- + () -> colorSpaceType.getWidth(new int[] {}));
- + assertThat(exception).hasMessageThat().contains(
- + "getWidth() only supports RGB and GRAYSCALE formats, but not "
- + + colorSpaceType.name());
- + }
- +
- + @Test
- + public void getHeightShouldFail() {
- + UnsupportedOperationException exception =
- + assertThrows(UnsupportedOperationException.class,
- + () -> colorSpaceType.getHeight(new int[] {}));
- + assertThat(exception).hasMessageThat().contains(
- + "getHeight() only supports RGB and GRAYSCALE formats, but not "
- + + colorSpaceType.name());
- + }
- +
- + @Test
- + public void assertShapeShouldFail() {
- + UnsupportedOperationException exception =
- + assertThrows(UnsupportedOperationException.class,
- + () -> colorSpaceType.assertShape(new int[] {}));
- + assertThat(exception).hasMessageThat().contains(
- + "assertShape() only supports RGB and GRAYSCALE formats, but not "
- + + colorSpaceType.name());
- + }
- +
- + @Test
- + public void getChannelValueShouldFail() {
- + UnsupportedOperationException exception = assertThrows(
- + UnsupportedOperationException.class, () -> colorSpaceType.getChannelValue());
- + assertThat(exception).hasMessageThat().contains(
- + "getChannelValue() is unsupported for the color space type "
- + + colorSpaceType.name());
- + }
- +
- + @Test
- + public void getNormalizedShapeShouldFail() {
- + UnsupportedOperationException exception =
- + assertThrows(UnsupportedOperationException.class,
- + () -> colorSpaceType.getNormalizedShape(new int[] {}));
- + assertThat(exception).hasMessageThat().contains(
- + "getNormalizedShape() is unsupported for the color space type "
- + + colorSpaceType.name());
- + }
- +
- + @Test
- + public void getShapeInfoMessageShouldFail() {
- + UnsupportedOperationException exception =
- + assertThrows(UnsupportedOperationException.class,
- + () -> colorSpaceType.getShapeInfoMessage());
- + assertThat(exception).hasMessageThat().contains(
- + "getShapeInfoMessage() is unsupported for the color space type "
- + + colorSpaceType.name());
- + }
- +
- + @Test
- + public void toBitmapConfigShouldFail() {
- + UnsupportedOperationException exception = assertThrows(
- + UnsupportedOperationException.class, () -> colorSpaceType.toBitmapConfig());
- + assertThat(exception).hasMessageThat().contains(
- + "toBitmapConfig() is unsupported for the color space type "
- + + colorSpaceType.name());
- + }
- }
-
- - @Test
- - public void fromBitmapConfigSucceedsWithSupportedConfig() {
- - assertThat(ColorSpaceType.fromBitmapConfig(config)).isEqualTo(colorSpaceType);
- - }
- -
- - @Test
- - public void toBitmapConfigSucceedsWithSupportedConfig() {
- - assertThat(colorSpaceType.toBitmapConfig()).isEqualTo(config);
- - }
- - }
- -
- - /** Parameterized tests for ImageFormat. */
- - @RunWith(ParameterizedRobolectricTestRunner.class)
- - public static final class ImageFormatTest extends ColorSpaceTypeTest {
- -
- - @Parameter(0)
- - public ColorSpaceType colorSpaceType;
- -
- - /** The ImageFormat that matches the colorSpaceType. */
- - @Parameter(1)
- - public int imageFormat;
- -
- - @Parameters(name = "colorSpaceType={0}; imageFormat={1}")
- - public static Collection<Object[]> data() {
- - return Arrays.asList(
- - new Object[][] {
- - {ColorSpaceType.NV21, ImageFormat.NV21},
- - {ColorSpaceType.YV12, ImageFormat.YV12},
- - {ColorSpaceType.YUV_420_888, ImageFormat.YUV_420_888},
- - });
- - }
- -
- - @Test
- - public void fromImageFormatSucceedsWithSupportedImageFormat() {
- - assertThat(ColorSpaceType.fromImageFormat(imageFormat)).isEqualTo(colorSpaceType);
- - }
- - }
- -
- - /** Parameterized tests for YUV image formats: NV12, NV21, YV12, YV21, YUV_420_888. */
- - @RunWith(ParameterizedRobolectricTestRunner.class)
- - public static final class YuvImageTest extends ColorSpaceTypeTest {
- -
- - @Parameter(0)
- - public ColorSpaceType colorSpaceType;
- -
- - @Parameters(name = "colorSpaceType={0}")
- - public static Collection<Object[]> data() {
- - return Arrays.asList(
- - new Object[][] {
- - {ColorSpaceType.NV12},
- - {ColorSpaceType.NV21},
- - {ColorSpaceType.YV12},
- - {ColorSpaceType.YV21},
- - {ColorSpaceType.YUV_420_888},
- - });
- - }
- -
- - @Test
- - public void convertTensorBufferToBitmapShouldFail() {
- - UnsupportedOperationException exception =
- - assertThrows(
- - UnsupportedOperationException.class,
- - () ->
- - colorSpaceType.convertTensorBufferToBitmap(
- - TensorBuffer.createDynamic(DataType.FLOAT32)));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "convertTensorBufferToBitmap() is unsupported for the color space type "
- - + colorSpaceType.name());
- - }
- -
- - @Test
- - public void getWidthShouldFail() {
- - UnsupportedOperationException exception =
- - assertThrows(
- - UnsupportedOperationException.class, () -> colorSpaceType.getWidth(new int[] {}));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "getWidth() only supports RGB and GRAYSCALE formats, but not "
- - + colorSpaceType.name());
- - }
- -
- - @Test
- - public void getHeightShouldFail() {
- - UnsupportedOperationException exception =
- - assertThrows(
- - UnsupportedOperationException.class, () -> colorSpaceType.getHeight(new int[] {}));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "getHeight() only supports RGB and GRAYSCALE formats, but not "
- - + colorSpaceType.name());
- - }
- -
- - @Test
- - public void assertShapeShouldFail() {
- - UnsupportedOperationException exception =
- - assertThrows(
- - UnsupportedOperationException.class, () -> colorSpaceType.assertShape(new int[] {}));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "assertShape() only supports RGB and GRAYSCALE formats, but not "
- - + colorSpaceType.name());
- - }
- -
- - @Test
- - public void getChannelValueShouldFail() {
- - UnsupportedOperationException exception =
- - assertThrows(UnsupportedOperationException.class, () -> colorSpaceType.getChannelValue());
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "getChannelValue() is unsupported for the color space type " + colorSpaceType.name());
- - }
- -
- - @Test
- - public void getNormalizedShapeShouldFail() {
- - UnsupportedOperationException exception =
- - assertThrows(
- - UnsupportedOperationException.class,
- - () -> colorSpaceType.getNormalizedShape(new int[] {}));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "getNormalizedShape() is unsupported for the color space type "
- - + colorSpaceType.name());
- - }
- -
- - @Test
- - public void getShapeInfoMessageShouldFail() {
- - UnsupportedOperationException exception =
- - assertThrows(
- - UnsupportedOperationException.class, () -> colorSpaceType.getShapeInfoMessage());
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "getShapeInfoMessage() is unsupported for the color space type "
- - + colorSpaceType.name());
- - }
- -
- - @Test
- - public void toBitmapConfigShouldFail() {
- - UnsupportedOperationException exception =
- - assertThrows(UnsupportedOperationException.class, () -> colorSpaceType.toBitmapConfig());
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "toBitmapConfig() is unsupported for the color space type " + colorSpaceType.name());
- - }
- - }
- -
- - /** Parameterized tests for assertNumElements/getNumElements with all image formats. */
- - @RunWith(ParameterizedRobolectricTestRunner.class)
- - public static final class AssertNumElementsTest extends ColorSpaceTypeTest {
- - private static final int HEIGHT = 2;
- - private static final int WIDTH = 3;
- - private static final int LESS_NUM_ELEMENTS = 5; // less than expected
- - private static final int MORE_NUM_ELEMENTS = 20; // more than expected. OK.
- - @Rule public ErrorCollector errorCollector = new ErrorCollector();
- -
- - @Parameter(0)
- - public ColorSpaceType colorSpaceType;
- -
- - @Parameter(1)
- - public int expectedNumElements;
- -
- - @Parameters(name = "colorSpaceType={0};expectedNumElements={1}")
- - public static Collection<Object[]> data() {
- - return Arrays.asList(
- - new Object[][] {
- - {ColorSpaceType.RGB, 18},
- - {ColorSpaceType.GRAYSCALE, 6},
- - {ColorSpaceType.NV12, 10},
- - {ColorSpaceType.NV21, 10},
- - {ColorSpaceType.YV12, 10},
- - {ColorSpaceType.YV21, 10},
- - });
- - }
- -
- - @Test
- - public void getNumElementsShouldSucceedWithExpectedNumElements() {
- - assertThat(colorSpaceType.getNumElements(HEIGHT, WIDTH)).isEqualTo(expectedNumElements);
- - }
- -
- - @Test
- - public void assertNumElementsShouldSucceedWithMoreNumElements() {
- - errorCollector.checkSucceeds(
- - () -> {
- - colorSpaceType.assertNumElements(MORE_NUM_ELEMENTS, HEIGHT, WIDTH);
- - return null;
- - });
- - }
- -
- - @Test
- - public void assertNumElementsShouldFailWithLessNumElements() {
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class,
- - () -> colorSpaceType.assertNumElements(LESS_NUM_ELEMENTS, HEIGHT, WIDTH));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - String.format(
- - "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
- - + " expected number of elements should be at least %d.",
- - LESS_NUM_ELEMENTS, colorSpaceType.name(), HEIGHT, WIDTH, expectedNumElements));
- - }
- - }
- -
- - /** General tests of ColorSpaceTypeTest. */
- - @RunWith(RobolectricTestRunner.class)
- - public static final class General extends ColorSpaceTypeTest {
- -
- - @Test
- - public void convertTensorBufferToBitmapShouldSuccessWithRGB() {
- - TensorBuffer buffer = createRgbTensorBuffer(DataType.UINT8, false);
- - Bitmap bitmap = ColorSpaceType.RGB.convertTensorBufferToBitmap(buffer);
- -
- - Bitmap expectedBitmap = createRgbBitmap();
- - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- + /** Parameterized tests for assertNumElements/getNumElements with all image formats. */
- + @RunWith(ParameterizedRobolectricTestRunner.class)
- + public static final class AssertNumElementsTest extends ColorSpaceTypeTest {
- + private static final int HEIGHT = 2;
- + private static final int WIDTH = 3;
- + private static final int LESS_NUM_ELEMENTS = 5; // less than expected
- + private static final int MORE_NUM_ELEMENTS = 20; // more than expected. OK.
- + @Rule
- + public ErrorCollector errorCollector = new ErrorCollector();
- +
- + @Parameter(0)
- + public ColorSpaceType colorSpaceType;
- +
- + @Parameter(1)
- + public int expectedNumElements;
- +
- + @Parameters(name = "colorSpaceType={0};expectedNumElements={1}")
- + public static Collection<Object[]> data() {
- + return Arrays.asList(new Object[][] {
- + {ColorSpaceType.RGB, 18},
- + {ColorSpaceType.GRAYSCALE, 6},
- + {ColorSpaceType.NV12, 10},
- + {ColorSpaceType.NV21, 10},
- + {ColorSpaceType.YV12, 10},
- + {ColorSpaceType.YV21, 10},
- + });
- + }
- +
- + @Test
- + public void getNumElementsShouldSucceedWithExpectedNumElements() {
- + assertThat(colorSpaceType.getNumElements(HEIGHT, WIDTH)).isEqualTo(expectedNumElements);
- + }
- +
- + @Test
- + public void assertNumElementsShouldSucceedWithMoreNumElements() {
- + errorCollector.checkSucceeds(() -> {
- + colorSpaceType.assertNumElements(MORE_NUM_ELEMENTS, HEIGHT, WIDTH);
- + return null;
- + });
- + }
- +
- + @Test
- + public void assertNumElementsShouldFailWithLessNumElements() {
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> colorSpaceType.assertNumElements(LESS_NUM_ELEMENTS, HEIGHT, WIDTH));
- + assertThat(exception).hasMessageThat().contains(String.format(
- + "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
- + + " expected number of elements should be at least %d.",
- + LESS_NUM_ELEMENTS, colorSpaceType.name(), HEIGHT, WIDTH, expectedNumElements));
- + }
- }
-
- - @Test
- - public void fromBitmapConfigFailsWithUnsupportedConfig() {
- - Config unsupportedConfig = Config.ARGB_4444;
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class,
- - () -> ColorSpaceType.fromBitmapConfig(unsupportedConfig));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains("Bitmap configuration: " + unsupportedConfig + ", is not supported yet.");
- + /** General tests of ColorSpaceTypeTest. */
- + @RunWith(RobolectricTestRunner.class)
- + public static final class General extends ColorSpaceTypeTest {
- + @Test
- + public void convertTensorBufferToBitmapShouldSuccessWithRGB() {
- + TensorBuffer buffer = createRgbTensorBuffer(DataType.UINT8, false);
- + Bitmap bitmap = ColorSpaceType.RGB.convertTensorBufferToBitmap(buffer);
- +
- + Bitmap expectedBitmap = createRgbBitmap();
- + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- + }
- +
- + @Test
- + public void fromBitmapConfigFailsWithUnsupportedConfig() {
- + Config unsupportedConfig = Config.ARGB_4444;
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> ColorSpaceType.fromBitmapConfig(unsupportedConfig));
- + assertThat(exception).hasMessageThat().contains(
- + "Bitmap configuration: " + unsupportedConfig + ", is not supported yet.");
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java
- index 1a4d367bf0fe1..49efc4273911c 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java
- @@ -21,7 +21,9 @@ import static android.graphics.Color.BLUE;
- import static android.graphics.Color.GREEN;
- import static android.graphics.Color.RED;
- import static android.graphics.Color.WHITE;
- +
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertThrows;
- import static org.tensorflow.lite.support.image.ImageConversions.convertGrayscaleTensorBufferToBitmap;
-
- @@ -30,10 +32,10 @@ import android.content.res.AssetManager;
- import android.graphics.Bitmap;
- import android.graphics.BitmapFactory;
- import android.util.Log;
- +
- import androidx.test.core.app.ApplicationProvider;
- import androidx.test.ext.junit.runners.AndroidJUnit4;
- -import java.io.IOException;
- -import java.util.Arrays;
- +
- import org.junit.Assert;
- import org.junit.Before;
- import org.junit.Test;
- @@ -43,192 +45,190 @@ import org.junit.runners.Suite.SuiteClasses;
- import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- +import java.io.IOException;
- +import java.util.Arrays;
- +
- /** Instrumented unit test for {@link ImageConversions}. */
- @RunWith(Suite.class)
- -@SuiteClasses({
- - ImageConversionsInstrumentedTest.TensorBufferToBitmap.class,
- - ImageConversionsInstrumentedTest.BitmapToTensorBuffer.class
- -})
- +@SuiteClasses({ImageConversionsInstrumentedTest.TensorBufferToBitmap.class,
- + ImageConversionsInstrumentedTest.BitmapToTensorBuffer.class})
- public class ImageConversionsInstrumentedTest {
- + /** Tests for the TensorBuffer data type and normalized form. */
- + // Note that parameterized test with android_library_instrumentation_tests is currently not
- + // supported internally.
- + @RunWith(AndroidJUnit4.class)
- + public static final class TensorBufferToBitmap extends ImageConversionsInstrumentedTest {
- + @Test
- + public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatNormalized() {
- + DataType dataType = DataType.FLOAT32;
- + boolean isNormalized = true;
- +
- + TensorBuffer buffer =
- + TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
- + Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
- +
- + Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
- + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- + }
-
- - /** Tests for the TensorBuffer data type and normalized form. */
- - // Note that parameterized test with android_library_instrumentation_tests is currently not
- - // supported internally.
- - @RunWith(AndroidJUnit4.class)
- - public static final class TensorBufferToBitmap extends ImageConversionsInstrumentedTest {
- -
- - @Test
- - public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatNormalized() {
- - DataType dataType = DataType.FLOAT32;
- - boolean isNormalized = true;
- + @Test
- + public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatUnnormalized() {
- + DataType dataType = DataType.FLOAT32;
- + boolean isNormalized = false;
-
- - TensorBuffer buffer = TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
- - Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
- + TensorBuffer buffer =
- + TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
- + Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
-
- - Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
- - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- - }
- -
- - @Test
- - public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatUnnormalized() {
- - DataType dataType = DataType.FLOAT32;
- - boolean isNormalized = false;
- + Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
- + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- + }
-
- - TensorBuffer buffer = TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
- - Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
- + @Test
- + public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Normalized() {
- + DataType dataType = DataType.UINT8;
- + boolean isNormalized = true;
-
- - Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
- - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- - }
- + TensorBuffer buffer =
- + TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
- + Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
-
- - @Test
- - public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Normalized() {
- - DataType dataType = DataType.UINT8;
- - boolean isNormalized = true;
- -
- - TensorBuffer buffer = TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
- - Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
- + Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
- + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- + }
-
- - Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
- - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- - }
- + @Test
- + public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Unnormalized() {
- + DataType dataType = DataType.UINT8;
- + boolean isNormalized = false;
-
- - @Test
- - public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Unnormalized() {
- - DataType dataType = DataType.UINT8;
- - boolean isNormalized = false;
- + TensorBuffer buffer =
- + TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
- + Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
-
- - TensorBuffer buffer = TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
- - Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
- + Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
- + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- + }
-
- - Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
- - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- - }
- + @Test
- + public void
- + convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithFloat() {
- + DataType dataType = DataType.FLOAT32;
- + TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType);
- +
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> convertGrayscaleTensorBufferToBitmap(buffer));
- + assertThat(exception).hasMessageThat().contains(
- + "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
- + + " shape is " + Arrays.toString(buffer.getShape()));
- + }
-
- - @Test
- - public void convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithFloat() {
- - DataType dataType = DataType.FLOAT32;
- - TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType);
- -
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> convertGrayscaleTensorBufferToBitmap(buffer));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
- - + " shape is "
- - + Arrays.toString(buffer.getShape()));
- + @Test
- + public void
- + convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithUint8() {
- + DataType dataType = DataType.UINT8;
- + TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType);
- +
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> convertGrayscaleTensorBufferToBitmap(buffer));
- + assertThat(exception).hasMessageThat().contains(
- + "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
- + + " shape is " + Arrays.toString(buffer.getShape()));
- + }
- }
-
- - @Test
- - public void convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithUint8() {
- - DataType dataType = DataType.UINT8;
- - TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType);
- -
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> convertGrayscaleTensorBufferToBitmap(buffer));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
- - + " shape is "
- - + Arrays.toString(buffer.getShape()));
- - }
- - }
- -
- - /** BitmapToTensorBuffer tests of ImageConversionsInstrumentedTest. */
- - @RunWith(AndroidJUnit4.class)
- - public static final class BitmapToTensorBuffer extends ImageConversionsInstrumentedTest {
- -
- - private Bitmap greyGrid;
- - private Bitmap colorGrid;
- - private TensorBuffer buffer;
- -
- - static final String GREY_GRID_PATH = "grey_grid.png";
- - static final String COLOR_GRID_PATH = "color_grid.png";
- -
- - @Before
- - public void loadAssets() {
- - Context context = ApplicationProvider.getApplicationContext();
- - AssetManager assetManager = context.getAssets();
- - try {
- - greyGrid = BitmapFactory.decodeStream(assetManager.open(GREY_GRID_PATH));
- - colorGrid = BitmapFactory.decodeStream(assetManager.open(COLOR_GRID_PATH));
- - } catch (IOException e) {
- - Log.e("Test", "Cannot load asset files");
- - }
- - Assert.assertEquals(ARGB_8888, greyGrid.getConfig());
- - Assert.assertEquals(ARGB_8888, colorGrid.getConfig());
- - buffer = TensorBuffer.createDynamic(DataType.UINT8);
- - }
- + /** BitmapToTensorBuffer tests of ImageConversionsInstrumentedTest. */
- + @RunWith(AndroidJUnit4.class)
- + public static final class BitmapToTensorBuffer extends ImageConversionsInstrumentedTest {
- + private Bitmap greyGrid;
- + private Bitmap colorGrid;
- + private TensorBuffer buffer;
- +
- + static final String GREY_GRID_PATH = "grey_grid.png";
- + static final String COLOR_GRID_PATH = "color_grid.png";
- +
- + @Before
- + public void loadAssets() {
- + Context context = ApplicationProvider.getApplicationContext();
- + AssetManager assetManager = context.getAssets();
- + try {
- + greyGrid = BitmapFactory.decodeStream(assetManager.open(GREY_GRID_PATH));
- + colorGrid = BitmapFactory.decodeStream(assetManager.open(COLOR_GRID_PATH));
- + } catch (IOException e) {
- + Log.e("Test", "Cannot load asset files");
- + }
- + Assert.assertEquals(ARGB_8888, greyGrid.getConfig());
- + Assert.assertEquals(ARGB_8888, colorGrid.getConfig());
- + buffer = TensorBuffer.createDynamic(DataType.UINT8);
- + }
-
- - @Test
- - public void testBitmapDimensionLayout() {
- - // This test is not only for proving the correctness of bitmap -> TensorBuffer conversion, but
- - // also for us to better understand how Android Bitmap is storing pixels - height first or
- - // width first.
- - // We use a black image which has a white corner to understand what happens. By setting up the
- - // correct loop to pass the test, we can reveal the order of pixels returned from `getPixels`.
- - // The result shows that Android stores bitmap in an h-first manner. The returned array of
- - // `getPixels` is like [ 1st row, 2nd row, ... ] which is the same with TFLite.
- - Assert.assertEquals(100, greyGrid.getWidth());
- - Assert.assertEquals(100, greyGrid.getHeight());
- - Assert.assertEquals(BLACK, greyGrid.getPixel(25, 25)); // left top
- - Assert.assertEquals(BLACK, greyGrid.getPixel(75, 25)); // right top
- - Assert.assertEquals(WHITE, greyGrid.getPixel(25, 75)); // left bottom
- - Assert.assertEquals(BLACK, greyGrid.getPixel(75, 75)); // right bottom
- -
- - ImageConversions.convertBitmapToTensorBuffer(greyGrid, buffer);
- - Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape());
- - Assert.assertEquals(DataType.UINT8, buffer.getDataType());
- -
- - int[] pixels = buffer.getIntArray();
- - int index = 0;
- - for (int h = 0; h < 100; h++) {
- - for (int w = 0; w < 100; w++) {
- - int expected = (w < 50 && h >= 50) ? 255 : 0;
- - Assert.assertEquals(expected, pixels[index++]);
- - Assert.assertEquals(expected, pixels[index++]);
- - Assert.assertEquals(expected, pixels[index++]);
- + @Test
- + public void testBitmapDimensionLayout() {
- + // This test is not only for proving the correctness of bitmap -> TensorBuffer
- + // conversion, but also for us to better understand how Android Bitmap is storing pixels
- + // - height first or width first. We use a black image which has a white corner to
- + // understand what happens. By setting up the correct loop to pass the test, we can
- + // reveal the order of pixels returned from `getPixels`. The result shows that Android
- + // stores bitmap in an h-first manner. The returned array of `getPixels` is like [ 1st
- + // row, 2nd row, ... ] which is the same with TFLite.
- + Assert.assertEquals(100, greyGrid.getWidth());
- + Assert.assertEquals(100, greyGrid.getHeight());
- + Assert.assertEquals(BLACK, greyGrid.getPixel(25, 25)); // left top
- + Assert.assertEquals(BLACK, greyGrid.getPixel(75, 25)); // right top
- + Assert.assertEquals(WHITE, greyGrid.getPixel(25, 75)); // left bottom
- + Assert.assertEquals(BLACK, greyGrid.getPixel(75, 75)); // right bottom
- +
- + ImageConversions.convertBitmapToTensorBuffer(greyGrid, buffer);
- + Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape());
- + Assert.assertEquals(DataType.UINT8, buffer.getDataType());
- +
- + int[] pixels = buffer.getIntArray();
- + int index = 0;
- + for (int h = 0; h < 100; h++) {
- + for (int w = 0; w < 100; w++) {
- + int expected = (w < 50 && h >= 50) ? 255 : 0;
- + Assert.assertEquals(expected, pixels[index++]);
- + Assert.assertEquals(expected, pixels[index++]);
- + Assert.assertEquals(expected, pixels[index++]);
- + }
- + }
- }
- - }
- - }
-
- - @Test
- - public void testBitmapARGB8888ChannelLayout() {
- - // This test is not only for proving the correctness of bitmap -> TensorBuffer conversion, but
- - // also for us to better understand how Android Bitmap is storing pixels - RGB channel or
- - // other possible ordering.
- - // We use an colored grid image to understand what happens. It's a simple grid image with 4
- - // grid in different colors. Passed through our Bitmap -> TensorBuffer conversion which simply
- - // unpack channels from an integer returned from `getPixel`, its channel sequence could be
- - // revealed directly.
- - // The result shows that Android Bitmap has no magic when loading channels. If loading from
- - // PNG images, channel order still remains R-G-B.
- - Assert.assertEquals(100, colorGrid.getWidth());
- - Assert.assertEquals(100, colorGrid.getHeight());
- - Assert.assertEquals(BLUE, colorGrid.getPixel(25, 25)); // left top
- - Assert.assertEquals(BLACK, colorGrid.getPixel(75, 25)); // right top
- - Assert.assertEquals(GREEN, colorGrid.getPixel(25, 75)); // left bottom
- - Assert.assertEquals(RED, colorGrid.getPixel(75, 75)); // right bottom
- -
- - ImageConversions.convertBitmapToTensorBuffer(colorGrid, buffer);
- - Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape());
- - Assert.assertEquals(DataType.UINT8, buffer.getDataType());
- -
- - int[] pixels = buffer.getIntArray();
- - Assert.assertArrayEquals(new int[] {0, 0, 255}, getChannels(pixels, 25, 25)); // left top
- - Assert.assertArrayEquals(new int[] {0, 0, 0}, getChannels(pixels, 25, 75)); // right top
- - Assert.assertArrayEquals(new int[] {0, 255, 0}, getChannels(pixels, 75, 25)); // left bottom
- - Assert.assertArrayEquals(new int[] {255, 0, 0}, getChannels(pixels, 75, 75)); // right bottom
- - }
- + @Test
- + public void testBitmapARGB8888ChannelLayout() {
- + // This test is not only for proving the correctness of bitmap -> TensorBuffer
- + // conversion, but also for us to better understand how Android Bitmap is storing pixels
- + // - RGB channel or other possible ordering. We use an colored grid image to understand
- + // what happens. It's a simple grid image with 4 grid in different colors. Passed
- + // through our Bitmap -> TensorBuffer conversion which simply unpack channels from an
- + // integer returned from `getPixel`, its channel sequence could be revealed directly.
- + // The result shows that Android Bitmap has no magic when loading channels. If loading
- + // from PNG images, channel order still remains R-G-B.
- + Assert.assertEquals(100, colorGrid.getWidth());
- + Assert.assertEquals(100, colorGrid.getHeight());
- + Assert.assertEquals(BLUE, colorGrid.getPixel(25, 25)); // left top
- + Assert.assertEquals(BLACK, colorGrid.getPixel(75, 25)); // right top
- + Assert.assertEquals(GREEN, colorGrid.getPixel(25, 75)); // left bottom
- + Assert.assertEquals(RED, colorGrid.getPixel(75, 75)); // right bottom
- +
- + ImageConversions.convertBitmapToTensorBuffer(colorGrid, buffer);
- + Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape());
- + Assert.assertEquals(DataType.UINT8, buffer.getDataType());
- +
- + int[] pixels = buffer.getIntArray();
- + Assert.assertArrayEquals(
- + new int[] {0, 0, 255}, getChannels(pixels, 25, 25)); // left top
- + Assert.assertArrayEquals(new int[] {0, 0, 0}, getChannels(pixels, 25, 75)); // right top
- + Assert.assertArrayEquals(
- + new int[] {0, 255, 0}, getChannels(pixels, 75, 25)); // left bottom
- + Assert.assertArrayEquals(
- + new int[] {255, 0, 0}, getChannels(pixels, 75, 75)); // right bottom
- + }
-
- - /** Helper function only for {@link #testBitmapARGB8888ChannelLayout()}. */
- - private static int[] getChannels(int[] pixels, int h, int w) {
- - int id = (h * 100 + w) * 3;
- - return new int[] {pixels[id++], pixels[id++], pixels[id]};
- + /** Helper function only for {@link #testBitmapARGB8888ChannelLayout()}. */
- + private static int[] getChannels(int[] pixels, int h, int w) {
- + int id = (h * 100 + w) * 3;
- + return new int[] {pixels[id++], pixels[id++], pixels[id]};
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java
- index b3300872c2357..c91db9d184f63 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java
- @@ -16,13 +16,13 @@ limitations under the License.
- package org.tensorflow.lite.support.image;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertThrows;
- import static org.tensorflow.lite.support.image.ImageConversions.convertBitmapToTensorBuffer;
- import static org.tensorflow.lite.support.image.ImageConversions.convertRgbTensorBufferToBitmap;
-
- import android.graphics.Bitmap;
- -import java.util.Arrays;
- -import java.util.Collection;
- +
- import org.junit.Assert;
- import org.junit.Test;
- import org.junit.runner.RunWith;
- @@ -35,93 +35,93 @@ import org.robolectric.RobolectricTestRunner;
- import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- +import java.util.Arrays;
- +import java.util.Collection;
- +
- /** Tests of {@link ImageConversions}. */
- @RunWith(Suite.class)
- @SuiteClasses({ImageConversionsTest.TensorBufferToBitmap.class, ImageConversionsTest.General.class})
- public class ImageConversionsTest {
- -
- - /** Parameterized tests for the TensorBuffer data type and normalized form. */
- - @RunWith(ParameterizedRobolectricTestRunner.class)
- - public static final class TensorBufferToBitmap extends ImageConversionsTest {
- -
- - /** The data type that used to create the TensorBuffer. */
- - @Parameter(0)
- - public DataType dataType;
- -
- - /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */
- - @Parameter(1)
- - public boolean isNormalized;
- -
- - @Parameters(name = "dataType={0}; isNormalized={1}")
- - public static Collection<Object[]> data() {
- - return Arrays.asList(
- - new Object[][] {
- - {DataType.FLOAT32, true}, {DataType.UINT8, true},
- - {DataType.FLOAT32, false}, {DataType.UINT8, false},
- - });
- - }
- -
- - @Test
- - public void convertRgbTensorBufferToBitmapShouldSuccess() {
- - TensorBuffer buffer = TestImageCreator.createRgbTensorBuffer(dataType, isNormalized);
- - Bitmap bitmap = convertRgbTensorBufferToBitmap(buffer);
- -
- - Bitmap expectedBitmap = TestImageCreator.createRgbBitmap();
- - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- - }
- -
- - @Test
- - public void convertRgbTensorBufferToBitmapShouldRejectBufferWithInvalidShape() {
- - TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10, 3}, dataType);
- -
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> convertRgbTensorBufferToBitmap(buffer));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
- - + " representing R, G, B in order. The provided image shape is "
- - + Arrays.toString(buffer.getShape()));
- - }
- - }
- -
- - /** General tests of ImageConversionsTest. */
- - @RunWith(RobolectricTestRunner.class)
- - public static final class General extends ImageConversionsTest {
- -
- - private static final Bitmap rgbBitmap = TestImageCreator.createRgbBitmap();
- - private static final TensorBuffer rgbTensorBuffer =
- - TestImageCreator.createRgbTensorBuffer(DataType.UINT8, false);
- -
- - @Test
- - public void convertBitmapToTensorBufferShouldSuccess() {
- - TensorBuffer intBuffer = TensorBuffer.createFixedSize(new int[] {10, 10, 3}, DataType.UINT8);
- - convertBitmapToTensorBuffer(rgbBitmap, intBuffer);
- - assertThat(areEqualIntTensorBuffer(intBuffer, rgbTensorBuffer)).isTrue();
- - }
- -
- - @Test
- - public void convertBitmapToTensorBufferShouldThrowShapeNotExactlySame() {
- - TensorBuffer intBuffer = TensorBuffer.createFixedSize(new int[] {5, 20, 3}, DataType.UINT8);
- - Assert.assertThrows(
- - IllegalArgumentException.class, () -> convertBitmapToTensorBuffer(rgbBitmap, intBuffer));
- + /** Parameterized tests for the TensorBuffer data type and normalized form. */
- + @RunWith(ParameterizedRobolectricTestRunner.class)
- + public static final class TensorBufferToBitmap extends ImageConversionsTest {
- + /** The data type that used to create the TensorBuffer. */
- + @Parameter(0)
- + public DataType dataType;
- +
- + /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */
- + @Parameter(1)
- + public boolean isNormalized;
- +
- + @Parameters(name = "dataType={0}; isNormalized={1}")
- + public static Collection<Object[]> data() {
- + return Arrays.asList(new Object[][] {
- + {DataType.FLOAT32, true},
- + {DataType.UINT8, true},
- + {DataType.FLOAT32, false},
- + {DataType.UINT8, false},
- + });
- + }
- +
- + @Test
- + public void convertRgbTensorBufferToBitmapShouldSuccess() {
- + TensorBuffer buffer = TestImageCreator.createRgbTensorBuffer(dataType, isNormalized);
- + Bitmap bitmap = convertRgbTensorBufferToBitmap(buffer);
- +
- + Bitmap expectedBitmap = TestImageCreator.createRgbBitmap();
- + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- + }
- +
- + @Test
- + public void convertRgbTensorBufferToBitmapShouldRejectBufferWithInvalidShape() {
- + TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10, 3}, dataType);
- +
- + IllegalArgumentException exception = assertThrows(
- + IllegalArgumentException.class, () -> convertRgbTensorBufferToBitmap(buffer));
- + assertThat(exception).hasMessageThat().contains(
- + "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
- + + " representing R, G, B in order. The provided image shape is "
- + + Arrays.toString(buffer.getShape()));
- + }
- }
-
- - @Test
- - public void convertBitmapToTensorBufferShouldCastIntToFloatIfNeeded() {
- - TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- - convertBitmapToTensorBuffer(rgbBitmap, floatBuffer);
- - assertThat(areEqualIntTensorBuffer(floatBuffer, rgbTensorBuffer)).isTrue();
- + /** General tests of ImageConversionsTest. */
- + @RunWith(RobolectricTestRunner.class)
- + public static final class General extends ImageConversionsTest {
- + private static final Bitmap rgbBitmap = TestImageCreator.createRgbBitmap();
- + private static final TensorBuffer rgbTensorBuffer =
- + TestImageCreator.createRgbTensorBuffer(DataType.UINT8, false);
- +
- + @Test
- + public void convertBitmapToTensorBufferShouldSuccess() {
- + TensorBuffer intBuffer =
- + TensorBuffer.createFixedSize(new int[] {10, 10, 3}, DataType.UINT8);
- + convertBitmapToTensorBuffer(rgbBitmap, intBuffer);
- + assertThat(areEqualIntTensorBuffer(intBuffer, rgbTensorBuffer)).isTrue();
- + }
- +
- + @Test
- + public void convertBitmapToTensorBufferShouldThrowShapeNotExactlySame() {
- + TensorBuffer intBuffer =
- + TensorBuffer.createFixedSize(new int[] {5, 20, 3}, DataType.UINT8);
- + Assert.assertThrows(IllegalArgumentException.class,
- + () -> convertBitmapToTensorBuffer(rgbBitmap, intBuffer));
- + }
- +
- + @Test
- + public void convertBitmapToTensorBufferShouldCastIntToFloatIfNeeded() {
- + TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- + convertBitmapToTensorBuffer(rgbBitmap, floatBuffer);
- + assertThat(areEqualIntTensorBuffer(floatBuffer, rgbTensorBuffer)).isTrue();
- + }
- }
- - }
-
- - private static boolean areEqualIntTensorBuffer(TensorBuffer tb1, TensorBuffer tb2) {
- - if (!Arrays.equals(tb1.getShape(), tb2.getShape())) {
- - return false;
- + private static boolean areEqualIntTensorBuffer(TensorBuffer tb1, TensorBuffer tb2) {
- + if (!Arrays.equals(tb1.getShape(), tb2.getShape())) {
- + return false;
- + }
- + int[] arr1 = tb1.getIntArray();
- + int[] arr2 = tb2.getIntArray();
- + return Arrays.equals(arr1, arr2);
- }
- - int[] arr1 = tb1.getIntArray();
- - int[] arr2 = tb2.getIntArray();
- - return Arrays.equals(arr1, arr2);
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java
- index 8ac27fdb07ad1..e9cbfc1dc50bd 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java
- @@ -16,10 +16,13 @@ limitations under the License.
- package org.tensorflow.lite.support.image;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertThrows;
-
- import android.graphics.Bitmap;
- +
- import androidx.test.ext.junit.runners.AndroidJUnit4;
- +
- import org.junit.Before;
- import org.junit.Test;
- import org.junit.runner.RunWith;
- @@ -30,120 +33,114 @@ import org.tensorflow.lite.support.image.ops.Rot90Op;
- /** Instrumented unit test for {@link ImageProcessor}. */
- @RunWith(AndroidJUnit4.class)
- public final class ImageProcessorInstrumentedTest {
- + private Bitmap exampleBitmap;
- + private TensorImage input;
- + private ImageProcessor processor;
- +
- + private static final int EXAMPLE_WIDTH = 10;
- + private static final int EXAMPLE_HEIGHT = 15;
- +
- + @Before
- + public void setUp() {
- + // The default number of rotation is once.
- + processor = new ImageProcessor.Builder().add(new Rot90Op()).build();
- + exampleBitmap = createExampleBitmap();
- + input = new TensorImage(DataType.UINT8);
- + input.load(exampleBitmap);
- + }
- +
- + @Test
- + public void updateNumberOfRotations_rotateTwice() {
- + int numberOfRotations = 2;
- +
- + processor.updateNumberOfRotations(numberOfRotations);
- + TensorImage output = processor.process(input);
- +
- + Bitmap outputBitmap = output.getBitmap();
- + assertExampleBitmapWithTwoRotations(outputBitmap);
- + }
- +
- + @Test
- + public void updateNumberOfRotationsWithOpIndex_rotateTwiceAndOpIndex0() {
- + int numberOfRotations = 2;
- + int occurrence = 0;
- +
- + processor.updateNumberOfRotations(numberOfRotations, occurrence);
- + TensorImage output = processor.process(input);
- +
- + Bitmap outputBitmap = output.getBitmap();
- + assertExampleBitmapWithTwoRotations(outputBitmap);
- + }
- +
- + @Test
- + public void updateNumberOfRotationsWithOpIndex_negativeOpIndex() {
- + int numberOfRotations = 2;
- + int negativeOpIndex = -1;
- +
- + IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
- + () -> processor.updateNumberOfRotations(numberOfRotations, negativeOpIndex));
- + assertThat(exception).hasMessageThat().isEqualTo("occurrence (-1) must not be negative");
- + }
- +
- + @Test
- + public void updateNumberOfRotationsWithOpIndex_occurrenceEqualToTheNumberOfRot90Op() {
- + int numberOfRotations = 2;
- + int occurrence = 1;
- +
- + IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
- + () -> processor.updateNumberOfRotations(numberOfRotations, occurrence));
- + assertThat(exception).hasMessageThat().isEqualTo(
- + "occurrence (1) must be less than size (1)");
- + }
- +
- + @Test
- + public void updateNumberOfRotationsWithOpIndex_noRot90OpIsAddedToImageProcessor() {
- + int numberOfRotations = 2;
- + int occurrence = 1;
- + // Add an op other than Rot90Op into ImageProcessor.
- + ImageProcessor processor =
- + new ImageProcessor.Builder().add(new ResizeWithCropOrPadOp(5, 5)).build();
- +
- + IllegalStateException exception = assertThrows(IllegalStateException.class,
- + () -> processor.updateNumberOfRotations(numberOfRotations, occurrence));
- + assertThat(exception).hasMessageThat().isEqualTo(
- + "The Rot90Op has not been added to the ImageProcessor.");
- + }
- +
- + @Test
- + public void updateNumberOfRotationsWithOpIndex_twoRot90Ops() {
- + // The overall effect of the two rotations is equivalent to rotating for twice.
- + int numberOfRotations0 = 5;
- + int numberOfRotations1 = 1;
- +
- + // Add two Rot90Ops into ImageProcessor.
- + ImageProcessor processor =
- + new ImageProcessor.Builder().add(new Rot90Op()).add(new Rot90Op()).build();
- + processor.updateNumberOfRotations(numberOfRotations0, /*occurrence=*/0);
- + processor.updateNumberOfRotations(numberOfRotations1, /*occurrence=*/1);
- +
- + TensorImage output = processor.process(input);
- + Bitmap outputBitmap = output.getBitmap();
- + assertExampleBitmapWithTwoRotations(outputBitmap);
- + }
-
- - private Bitmap exampleBitmap;
- - private TensorImage input;
- - private ImageProcessor processor;
- -
- - private static final int EXAMPLE_WIDTH = 10;
- - private static final int EXAMPLE_HEIGHT = 15;
- -
- - @Before
- - public void setUp() {
- - // The default number of rotation is once.
- - processor = new ImageProcessor.Builder().add(new Rot90Op()).build();
- - exampleBitmap = createExampleBitmap();
- - input = new TensorImage(DataType.UINT8);
- - input.load(exampleBitmap);
- - }
- -
- - @Test
- - public void updateNumberOfRotations_rotateTwice() {
- - int numberOfRotations = 2;
- -
- - processor.updateNumberOfRotations(numberOfRotations);
- - TensorImage output = processor.process(input);
- -
- - Bitmap outputBitmap = output.getBitmap();
- - assertExampleBitmapWithTwoRotations(outputBitmap);
- - }
- -
- - @Test
- - public void updateNumberOfRotationsWithOpIndex_rotateTwiceAndOpIndex0() {
- - int numberOfRotations = 2;
- - int occurrence = 0;
- -
- - processor.updateNumberOfRotations(numberOfRotations, occurrence);
- - TensorImage output = processor.process(input);
- -
- - Bitmap outputBitmap = output.getBitmap();
- - assertExampleBitmapWithTwoRotations(outputBitmap);
- - }
- -
- - @Test
- - public void updateNumberOfRotationsWithOpIndex_negativeOpIndex() {
- - int numberOfRotations = 2;
- - int negativeOpIndex = -1;
- -
- - IndexOutOfBoundsException exception =
- - assertThrows(
- - IndexOutOfBoundsException.class,
- - () -> processor.updateNumberOfRotations(numberOfRotations, negativeOpIndex));
- - assertThat(exception).hasMessageThat().isEqualTo("occurrence (-1) must not be negative");
- - }
- -
- - @Test
- - public void updateNumberOfRotationsWithOpIndex_occurrenceEqualToTheNumberOfRot90Op() {
- - int numberOfRotations = 2;
- - int occurrence = 1;
- -
- - IndexOutOfBoundsException exception =
- - assertThrows(
- - IndexOutOfBoundsException.class,
- - () -> processor.updateNumberOfRotations(numberOfRotations, occurrence));
- - assertThat(exception).hasMessageThat().isEqualTo("occurrence (1) must be less than size (1)");
- - }
- -
- - @Test
- - public void updateNumberOfRotationsWithOpIndex_noRot90OpIsAddedToImageProcessor() {
- - int numberOfRotations = 2;
- - int occurrence = 1;
- - // Add an op other than Rot90Op into ImageProcessor.
- - ImageProcessor processor =
- - new ImageProcessor.Builder().add(new ResizeWithCropOrPadOp(5, 5)).build();
- -
- - IllegalStateException exception =
- - assertThrows(
- - IllegalStateException.class,
- - () -> processor.updateNumberOfRotations(numberOfRotations, occurrence));
- - assertThat(exception)
- - .hasMessageThat()
- - .isEqualTo("The Rot90Op has not been added to the ImageProcessor.");
- - }
- -
- - @Test
- - public void updateNumberOfRotationsWithOpIndex_twoRot90Ops() {
- - // The overall effect of the two rotations is equivalent to rotating for twice.
- - int numberOfRotations0 = 5;
- - int numberOfRotations1 = 1;
- -
- - // Add two Rot90Ops into ImageProcessor.
- - ImageProcessor processor =
- - new ImageProcessor.Builder().add(new Rot90Op()).add(new Rot90Op()).build();
- - processor.updateNumberOfRotations(numberOfRotations0, /*occurrence=*/ 0);
- - processor.updateNumberOfRotations(numberOfRotations1, /*occurrence=*/ 1);
- -
- - TensorImage output = processor.process(input);
- - Bitmap outputBitmap = output.getBitmap();
- - assertExampleBitmapWithTwoRotations(outputBitmap);
- - }
- -
- - private void assertExampleBitmapWithTwoRotations(Bitmap bitmapRotated) {
- - assertThat(bitmapRotated.getWidth()).isEqualTo(EXAMPLE_WIDTH);
- - assertThat(bitmapRotated.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
- - for (int i = 0; i < exampleBitmap.getWidth(); i++) {
- - for (int j = 0; j < exampleBitmap.getHeight(); j++) {
- - assertThat(exampleBitmap.getPixel(i, j))
- - .isEqualTo(bitmapRotated.getPixel(EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j));
- - }
- + private void assertExampleBitmapWithTwoRotations(Bitmap bitmapRotated) {
- + assertThat(bitmapRotated.getWidth()).isEqualTo(EXAMPLE_WIDTH);
- + assertThat(bitmapRotated.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
- + for (int i = 0; i < exampleBitmap.getWidth(); i++) {
- + for (int j = 0; j < exampleBitmap.getHeight(); j++) {
- + assertThat(exampleBitmap.getPixel(i, j))
- + .isEqualTo(bitmapRotated.getPixel(
- + EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j));
- + }
- + }
- }
- - }
-
- - private static Bitmap createExampleBitmap() {
- - int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
- - for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
- - colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
- + private static Bitmap createExampleBitmap() {
- + int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
- + for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
- + colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
- + }
- + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- }
- - return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java
- index a655f4a506900..a93ba5465125c 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java
- @@ -16,10 +16,12 @@ limitations under the License.
- package org.tensorflow.lite.support.image;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertThrows;
-
- import android.graphics.Bitmap;
- import android.graphics.RectF;
- +
- import org.junit.Test;
- import org.junit.runner.RunWith;
- import org.robolectric.RobolectricTestRunner;
- @@ -34,115 +36,112 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
- /** Tests for {@link ImageProcessor}. */
- @RunWith(RobolectricTestRunner.class)
- public final class ImageProcessorTest {
- + private static final int EXAMPLE_WIDTH = 10;
- + private static final int EXAMPLE_HEIGHT = 15;
- + private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH;
- + private static final int EXAMPLE_NUM_CHANNELS = 3;
- + private static final float MEAN = 127.5f;
- + private static final float STDDEV = 127.5f;
- +
- + @Test
- + public void testBuild() {
- + ImageProcessor processor =
- + new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
- + assertThat(processor).isNotNull();
- + }
-
- - private static final int EXAMPLE_WIDTH = 10;
- - private static final int EXAMPLE_HEIGHT = 15;
- - private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH;
- - private static final int EXAMPLE_NUM_CHANNELS = 3;
- - private static final float MEAN = 127.5f;
- - private static final float STDDEV = 127.5f;
- -
- - @Test
- - public void testBuild() {
- - ImageProcessor processor =
- - new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
- - assertThat(processor).isNotNull();
- - }
- -
- - @Test
- - public void testNormalize() {
- - TensorImage input = new TensorImage(DataType.FLOAT32);
- - input.load(createExampleBitmap());
- - ImageProcessor processor =
- - new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
- - TensorImage output = processor.process(input);
- -
- - float[] pixels = output.getTensorBuffer().getFloatArray();
- - assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS);
- - for (float p : pixels) {
- - assertThat(p).isAtLeast(-1);
- - assertThat(p).isAtMost(1);
- + @Test
- + public void testNormalize() {
- + TensorImage input = new TensorImage(DataType.FLOAT32);
- + input.load(createExampleBitmap());
- + ImageProcessor processor =
- + new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
- + TensorImage output = processor.process(input);
- +
- + float[] pixels = output.getTensorBuffer().getFloatArray();
- + assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS);
- + for (float p : pixels) {
- + assertThat(p).isAtLeast(-1);
- + assertThat(p).isAtMost(1);
- + }
- }
- - }
- -
- - @Test
- - public void testMultipleNormalize() {
- - TensorImage input = new TensorImage(DataType.FLOAT32);
- - input.load(createExampleBitmap());
- - ImageProcessor processor =
- - new ImageProcessor.Builder()
- - .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1]
- - .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1]
- - .build();
- - TensorImage output = processor.process(input);
- -
- - float[] pixels = output.getTensorBuffer().getFloatArray();
- - assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS);
- - for (float p : pixels) {
- - assertThat(p).isAtLeast(0);
- - assertThat(p).isAtMost(1);
- +
- + @Test
- + public void testMultipleNormalize() {
- + TensorImage input = new TensorImage(DataType.FLOAT32);
- + input.load(createExampleBitmap());
- + ImageProcessor processor =
- + new ImageProcessor.Builder()
- + .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1]
- + .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1]
- + .build();
- + TensorImage output = processor.process(input);
- +
- + float[] pixels = output.getTensorBuffer().getFloatArray();
- + assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS);
- + for (float p : pixels) {
- + assertThat(p).isAtLeast(0);
- + assertThat(p).isAtMost(1);
- + }
- }
- - }
- -
- - @Test
- - public void inverseTransformRectCorrectly() {
- - ImageProcessor processor =
- - new ImageProcessor.Builder()
- - .add(new ResizeOp(200, 300, ResizeMethod.BILINEAR))
- - .add(new ResizeWithCropOrPadOp(100, 200))
- - .add(new Rot90Op(1))
- - .add(new NormalizeOp(127, 128))
- - .build();
- - RectF transformed = new RectF(0, 50, 100, 150);
- - RectF original = processor.inverseTransform(transformed, 400, 600);
- - assertThat(original.top).isEqualTo(100);
- - assertThat(original.left).isEqualTo(200);
- - assertThat(original.right).isEqualTo(400);
- - assertThat(original.bottom).isEqualTo(300);
- - }
- -
- - @Test
- - public void resizeShouldFailWithNonRgbImages() {
- - int[] data = new int[] {1, 2, 3};
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- - tensorBuffer.loadArray(data, new int[] {1, 3});
- - TensorImage image = new TensorImage();
- - image.load(tensorBuffer, ColorSpaceType.GRAYSCALE);
- -
- - ImageProcessor processor =
- - new ImageProcessor.Builder().add(new ResizeOp(200, 300, ResizeMethod.BILINEAR)).build();
- -
- - IllegalArgumentException exception =
- - assertThrows(IllegalArgumentException.class, () -> processor.process(image));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "Only RGB images are supported in ResizeOp, but not "
- +
- + @Test
- + public void inverseTransformRectCorrectly() {
- + ImageProcessor processor = new ImageProcessor.Builder()
- + .add(new ResizeOp(200, 300, ResizeMethod.BILINEAR))
- + .add(new ResizeWithCropOrPadOp(100, 200))
- + .add(new Rot90Op(1))
- + .add(new NormalizeOp(127, 128))
- + .build();
- + RectF transformed = new RectF(0, 50, 100, 150);
- + RectF original = processor.inverseTransform(transformed, 400, 600);
- + assertThat(original.top).isEqualTo(100);
- + assertThat(original.left).isEqualTo(200);
- + assertThat(original.right).isEqualTo(400);
- + assertThat(original.bottom).isEqualTo(300);
- + }
- +
- + @Test
- + public void resizeShouldFailWithNonRgbImages() {
- + int[] data = new int[] {1, 2, 3};
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- + tensorBuffer.loadArray(data, new int[] {1, 3});
- + TensorImage image = new TensorImage();
- + image.load(tensorBuffer, ColorSpaceType.GRAYSCALE);
- +
- + ImageProcessor processor = new ImageProcessor.Builder()
- + .add(new ResizeOp(200, 300, ResizeMethod.BILINEAR))
- + .build();
- +
- + IllegalArgumentException exception =
- + assertThrows(IllegalArgumentException.class, () -> processor.process(image));
- + assertThat(exception).hasMessageThat().contains(
- + "Only RGB images are supported in ResizeOp, but not "
- + image.getColorSpaceType().name());
- - }
- -
- - @Test
- - public void normalizeShouldSuccessWithNonRgbImages() {
- - int[] data = new int[] {1, 2, 3};
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- - tensorBuffer.loadArray(data, new int[] {1, 3});
- - TensorImage image = new TensorImage();
- - image.load(tensorBuffer, ColorSpaceType.GRAYSCALE);
- -
- - ImageProcessor processor =
- - new ImageProcessor.Builder().add(new NormalizeOp(0.5f, 1f)).build();
- - TensorImage output = processor.process(image);
- -
- - float[] pixels = output.getTensorBuffer().getFloatArray();
- - assertThat(pixels).isEqualTo(new float[]{0.5f, 1.5f, 2.5f});
- - }
- -
- - private static Bitmap createExampleBitmap() {
- - int[] colors = new int[EXAMPLE_NUM_PIXELS];
- - for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) {
- - colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
- }
-
- - return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- - }
- + @Test
- + public void normalizeShouldSuccessWithNonRgbImages() {
- + int[] data = new int[] {1, 2, 3};
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- + tensorBuffer.loadArray(data, new int[] {1, 3});
- + TensorImage image = new TensorImage();
- + image.load(tensorBuffer, ColorSpaceType.GRAYSCALE);
- +
- + ImageProcessor processor =
- + new ImageProcessor.Builder().add(new NormalizeOp(0.5f, 1f)).build();
- + TensorImage output = processor.process(image);
- +
- + float[] pixels = output.getTensorBuffer().getFloatArray();
- + assertThat(pixels).isEqualTo(new float[] {0.5f, 1.5f, 2.5f});
- + }
- +
- + private static Bitmap createExampleBitmap() {
- + int[] colors = new int[EXAMPLE_NUM_PIXELS];
- + for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) {
- + colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
- + }
- +
- + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java
- index 7e61aa8d3ce58..e8caefcab8a04 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java
- @@ -16,20 +16,19 @@ limitations under the License.
- package org.tensorflow.lite.support.image;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertThrows;
- import static org.mockito.Mockito.when;
-
- import android.graphics.Bitmap;
- import android.media.Image;
- +
- import com.google.android.odml.image.BitmapMlImageBuilder;
- import com.google.android.odml.image.ByteBufferMlImageBuilder;
- import com.google.android.odml.image.MediaMlImageBuilder;
- import com.google.android.odml.image.MlImage;
- import com.google.android.odml.image.MlImage.ImageFormat;
- -import java.io.IOException;
- -import java.nio.ByteBuffer;
- -import java.util.Arrays;
- -import java.util.Collection;
- +
- import org.junit.Before;
- import org.junit.Test;
- import org.junit.runner.RunWith;
- @@ -42,139 +41,141 @@ import org.robolectric.ParameterizedRobolectricTestRunner.Parameter;
- import org.robolectric.ParameterizedRobolectricTestRunner.Parameters;
- import org.robolectric.RobolectricTestRunner;
-
- +import java.io.IOException;
- +import java.nio.ByteBuffer;
- +import java.util.Arrays;
- +import java.util.Collection;
- +
- /** Test for {@link MlImageAdapter}. */
- @RunWith(Suite.class)
- @SuiteClasses({
- - MlImageAdapterTest.CreateTensorImageFromSupportedByteBufferMlImage.class,
- - MlImageAdapterTest.CreateTensorImageFromUnsupportedByteBufferMlImage.class,
- - MlImageAdapterTest.General.class,
- + MlImageAdapterTest.CreateTensorImageFromSupportedByteBufferMlImage.class,
- + MlImageAdapterTest.CreateTensorImageFromUnsupportedByteBufferMlImage.class,
- + MlImageAdapterTest.General.class,
- })
- public class MlImageAdapterTest {
- -
- - @RunWith(ParameterizedRobolectricTestRunner.class)
- - public static final class CreateTensorImageFromSupportedByteBufferMlImage
- - extends MlImageAdapterTest {
- -
- - @Parameter(0)
- - @ImageFormat
- - public int imageFormat;
- -
- - @Parameter(1)
- - public ColorSpaceType colorSpaceType;
- -
- - @Parameters(name = "imageFormat={0}")
- - public static Collection<Object[]> data() {
- - return Arrays.asList(
- - new Object[][] {
- - {MlImage.IMAGE_FORMAT_RGB, ColorSpaceType.RGB},
- - {MlImage.IMAGE_FORMAT_ALPHA, ColorSpaceType.GRAYSCALE},
- - {MlImage.IMAGE_FORMAT_NV21, ColorSpaceType.NV21},
- - {MlImage.IMAGE_FORMAT_NV12, ColorSpaceType.NV12},
- - {MlImage.IMAGE_FORMAT_YV12, ColorSpaceType.YV12},
- - {MlImage.IMAGE_FORMAT_YV21, ColorSpaceType.YV21},
- - });
- - }
- -
- - @Test
- - public void createTensorImageFrom_supportedByteBufferMlImage_succeeds() throws IOException {
- - ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer();
- - buffer.rewind();
- - MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build();
- -
- - TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- -
- - assertThat(tensorImage.getWidth()).isEqualTo(1);
- - assertThat(tensorImage.getHeight()).isEqualTo(2);
- - assertThat(tensorImage.getColorSpaceType()).isEqualTo(colorSpaceType);
- - assertThat(tensorImage.getBuffer().position()).isEqualTo(0);
- - assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(buffer);
- - }
- - }
- -
- - @RunWith(ParameterizedRobolectricTestRunner.class)
- - public static final class CreateTensorImageFromUnsupportedByteBufferMlImage
- - extends MlImageAdapterTest {
- - @Parameter(0)
- - @ImageFormat
- - public int imageFormat;
- -
- - @Parameters(name = "imageFormat={0}")
- - public static Collection<Object[]> data() {
- - return Arrays.asList(
- - new Object[][] {
- - {MlImage.IMAGE_FORMAT_RGBA},
- - {MlImage.IMAGE_FORMAT_JPEG},
- - {MlImage.IMAGE_FORMAT_YUV_420_888},
- - {MlImage.IMAGE_FORMAT_UNKNOWN},
- - });
- - }
- -
- - @Test
- - public void createTensorImageFrom_unsupportedByteBufferMlImage_throws() throws IOException {
- - ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer();
- - buffer.rewind();
- - MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build();
- -
- - assertThrows(
- - IllegalArgumentException.class, () -> MlImageAdapter.createTensorImageFrom(image));
- - }
- - }
- -
- - @RunWith(RobolectricTestRunner.class)
- - public static final class General extends MlImageAdapterTest {
- -
- - @Mock Image mediaImageMock;
- -
- - @Before
- - public void setUp() {
- - MockitoAnnotations.openMocks(this);
- - }
- -
- - @Test
- - public void createTensorImageFrom_bitmapMlImage_succeeds() throws IOException {
- - Bitmap bitmap =
- - Bitmap.createBitmap(new int[] {0xff000100, 0xff000001}, 1, 2, Bitmap.Config.ARGB_8888);
- - MlImage image = new BitmapMlImageBuilder(bitmap).build();
- - ByteBuffer expectedBuffer = ByteBuffer.allocateDirect(6);
- - for (byte b : new byte[] {0, 1, 0, 0, 0, 1}) {
- - expectedBuffer.put(b);
- - }
- - expectedBuffer.rewind();
- -
- - TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- -
- - assertThat(tensorImage.getWidth()).isEqualTo(1);
- - assertThat(tensorImage.getHeight()).isEqualTo(2);
- - assertThat(tensorImage.getBuffer().position()).isEqualTo(0);
- - assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(expectedBuffer);
- - }
- -
- - @Test
- - public void createTensorImageFrom_yuv420888MediaImageMlImage_succeeds() throws IOException {
- - setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_420_888, 1, 2);
- - MlImage image = new MediaMlImageBuilder(mediaImageMock).build();
- -
- - TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- -
- - assertThat(tensorImage.getWidth()).isEqualTo(1);
- - assertThat(tensorImage.getHeight()).isEqualTo(2);
- - assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.YUV_420_888);
- + @RunWith(ParameterizedRobolectricTestRunner.class)
- + public static final class CreateTensorImageFromSupportedByteBufferMlImage
- + extends MlImageAdapterTest {
- + @Parameter(0)
- + @ImageFormat
- + public int imageFormat;
- +
- + @Parameter(1)
- + public ColorSpaceType colorSpaceType;
- +
- + @Parameters(name = "imageFormat={0}")
- + public static Collection<Object[]> data() {
- + return Arrays.asList(new Object[][] {
- + {MlImage.IMAGE_FORMAT_RGB, ColorSpaceType.RGB},
- + {MlImage.IMAGE_FORMAT_ALPHA, ColorSpaceType.GRAYSCALE},
- + {MlImage.IMAGE_FORMAT_NV21, ColorSpaceType.NV21},
- + {MlImage.IMAGE_FORMAT_NV12, ColorSpaceType.NV12},
- + {MlImage.IMAGE_FORMAT_YV12, ColorSpaceType.YV12},
- + {MlImage.IMAGE_FORMAT_YV21, ColorSpaceType.YV21},
- + });
- + }
- +
- + @Test
- + public void createTensorImageFrom_supportedByteBufferMlImage_succeeds() throws IOException {
- + ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer();
- + buffer.rewind();
- + MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build();
- +
- + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- +
- + assertThat(tensorImage.getWidth()).isEqualTo(1);
- + assertThat(tensorImage.getHeight()).isEqualTo(2);
- + assertThat(tensorImage.getColorSpaceType()).isEqualTo(colorSpaceType);
- + assertThat(tensorImage.getBuffer().position()).isEqualTo(0);
- + assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(buffer);
- + }
- }
-
- - @Test
- - public void createTensorImageFrom_nonYuv420888MediaImageMlImage_throws() throws IOException {
- - setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_422_888, 1, 2);
- - MlImage image = new MediaMlImageBuilder(mediaImageMock).build();
- -
- - assertThrows(
- - IllegalArgumentException.class, () -> MlImageAdapter.createTensorImageFrom(image));
- + @RunWith(ParameterizedRobolectricTestRunner.class)
- + public static final class CreateTensorImageFromUnsupportedByteBufferMlImage
- + extends MlImageAdapterTest {
- + @Parameter(0)
- + @ImageFormat
- + public int imageFormat;
- +
- + @Parameters(name = "imageFormat={0}")
- + public static Collection<Object[]> data() {
- + return Arrays.asList(new Object[][] {
- + {MlImage.IMAGE_FORMAT_RGBA},
- + {MlImage.IMAGE_FORMAT_JPEG},
- + {MlImage.IMAGE_FORMAT_YUV_420_888},
- + {MlImage.IMAGE_FORMAT_UNKNOWN},
- + });
- + }
- +
- + @Test
- + public void createTensorImageFrom_unsupportedByteBufferMlImage_throws() throws IOException {
- + ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer();
- + buffer.rewind();
- + MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build();
- +
- + assertThrows(IllegalArgumentException.class,
- + () -> MlImageAdapter.createTensorImageFrom(image));
- + }
- }
-
- - private static void setUpMediaImageMock(
- - Image mediaImageMock, int imageFormat, int width, int height) {
- - when(mediaImageMock.getFormat()).thenReturn(imageFormat);
- - when(mediaImageMock.getWidth()).thenReturn(width);
- - when(mediaImageMock.getHeight()).thenReturn(height);
- + @RunWith(RobolectricTestRunner.class)
- + public static final class General extends MlImageAdapterTest {
- + @Mock
- + Image mediaImageMock;
- +
- + @Before
- + public void setUp() {
- + MockitoAnnotations.openMocks(this);
- + }
- +
- + @Test
- + public void createTensorImageFrom_bitmapMlImage_succeeds() throws IOException {
- + Bitmap bitmap = Bitmap.createBitmap(
- + new int[] {0xff000100, 0xff000001}, 1, 2, Bitmap.Config.ARGB_8888);
- + MlImage image = new BitmapMlImageBuilder(bitmap).build();
- + ByteBuffer expectedBuffer = ByteBuffer.allocateDirect(6);
- + for (byte b : new byte[] {0, 1, 0, 0, 0, 1}) {
- + expectedBuffer.put(b);
- + }
- + expectedBuffer.rewind();
- +
- + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- +
- + assertThat(tensorImage.getWidth()).isEqualTo(1);
- + assertThat(tensorImage.getHeight()).isEqualTo(2);
- + assertThat(tensorImage.getBuffer().position()).isEqualTo(0);
- + assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(expectedBuffer);
- + }
- +
- + @Test
- + public void createTensorImageFrom_yuv420888MediaImageMlImage_succeeds() throws IOException {
- + setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_420_888, 1, 2);
- + MlImage image = new MediaMlImageBuilder(mediaImageMock).build();
- +
- + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- +
- + assertThat(tensorImage.getWidth()).isEqualTo(1);
- + assertThat(tensorImage.getHeight()).isEqualTo(2);
- + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.YUV_420_888);
- + }
- +
- + @Test
- + public void createTensorImageFrom_nonYuv420888MediaImageMlImage_throws()
- + throws IOException {
- + setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_422_888, 1, 2);
- + MlImage image = new MediaMlImageBuilder(mediaImageMock).build();
- +
- + assertThrows(IllegalArgumentException.class,
- + () -> MlImageAdapter.createTensorImageFrom(image));
- + }
- +
- + private static void setUpMediaImageMock(
- + Image mediaImageMock, int imageFormat, int width, int height) {
- + when(mediaImageMock.getFormat()).thenReturn(imageFormat);
- + when(mediaImageMock.getWidth()).thenReturn(width);
- + when(mediaImageMock.getHeight()).thenReturn(height);
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java
- index ca5f7dc7551be..83b54d0a8db78 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java
- @@ -15,6 +15,7 @@ limitations under the License.
- package org.tensorflow.lite.support.image;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.tensorflow.lite.DataType.FLOAT32;
- import static org.tensorflow.lite.DataType.UINT8;
- import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleBitmap;
- @@ -23,6 +24,7 @@ import static org.tensorflow.lite.support.image.TestImageCreator.createRgbBitmap
- import static org.tensorflow.lite.support.image.TestImageCreator.createRgbTensorBuffer;
-
- import android.graphics.Bitmap;
- +
- import org.junit.Test;
- import org.junit.runner.RunWith;
- import org.junit.runners.JUnit4;
- @@ -31,110 +33,110 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- @RunWith(JUnit4.class)
- public final class TensorImageInstrumentedTest {
- + /**
- + * Difference between the pair of float and uint8 values. It is used to test the data
- + * conversion.
- + */
- + private static final float DELTA = 0.1f;
- +
- + // Note that parameterized test with android_library_instrumentation_tests is currently not
- + // supported in internally.
- + @Test
- + public void loadAndGetBitmapSucceedsWithFloatBufferFloatImage() {
- + DataType tensorBufferDataType = FLOAT32;
- + DataType tensorImageDataType = FLOAT32;
- + boolean isNormalized = true;
- + ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
- +
- + TensorBuffer tensorBuffer =
- + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- + TensorImage tensorImage = new TensorImage(tensorImageDataType);
- +
- + tensorImage.load(tensorBuffer, colorSpaceType);
- + Bitmap bitmap = tensorImage.getBitmap();
- +
- + Bitmap expectedBitmap = createBitmap(colorSpaceType);
- + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- + }
- +
- + @Test
- + public void loadAndGetBitmapSucceedsWithFloatBufferUINT8Image() {
- + DataType tensorBufferDataType = FLOAT32;
- + DataType tensorImageDataType = UINT8;
- + boolean isNormalized = false;
- + ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
- +
- + TensorBuffer tensorBuffer =
- + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- + TensorImage tensorImage = new TensorImage(tensorImageDataType);
-
- - /**
- - * Difference between the pair of float and uint8 values. It is used to test the data conversion.
- - */
- - private static final float DELTA = 0.1f;
- -
- - // Note that parameterized test with android_library_instrumentation_tests is currently not
- - // supported in internally.
- - @Test
- - public void loadAndGetBitmapSucceedsWithFloatBufferFloatImage() {
- - DataType tensorBufferDataType = FLOAT32;
- - DataType tensorImageDataType = FLOAT32;
- - boolean isNormalized = true;
- - ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
- -
- - TensorBuffer tensorBuffer =
- - createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- - TensorImage tensorImage = new TensorImage(tensorImageDataType);
- -
- - tensorImage.load(tensorBuffer, colorSpaceType);
- - Bitmap bitmap = tensorImage.getBitmap();
- -
- - Bitmap expectedBitmap = createBitmap(colorSpaceType);
- - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- - }
- -
- - @Test
- - public void loadAndGetBitmapSucceedsWithFloatBufferUINT8Image() {
- - DataType tensorBufferDataType = FLOAT32;
- - DataType tensorImageDataType = UINT8;
- - boolean isNormalized = false;
- - ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
- -
- - TensorBuffer tensorBuffer =
- - createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- - TensorImage tensorImage = new TensorImage(tensorImageDataType);
- -
- - tensorImage.load(tensorBuffer, colorSpaceType);
- - Bitmap bitmap = tensorImage.getBitmap();
- -
- - Bitmap expectedBitmap = createBitmap(colorSpaceType);
- - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- - }
- -
- - @Test
- - public void loadAndGetBitmapSucceedsWithUINT8BufferFloatImage() {
- - DataType tensorBufferDataType = UINT8;
- - DataType tensorImageDataType = FLOAT32;
- - boolean isNormalized = true;
- - ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
- -
- - TensorBuffer tensorBuffer =
- - createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- - TensorImage tensorImage = new TensorImage(tensorImageDataType);
- -
- - tensorImage.load(tensorBuffer, colorSpaceType);
- - Bitmap bitmap = tensorImage.getBitmap();
- -
- - Bitmap expectedBitmap = createBitmap(colorSpaceType);
- - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- - }
- -
- - @Test
- - public void loadAndGetBitmapSucceedsWithUINT8BufferUINT8Image() {
- - DataType tensorBufferDataType = UINT8;
- - DataType tensorImageDataType = UINT8;
- - boolean isNormalized = false;
- - ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
- -
- - TensorBuffer tensorBuffer =
- - createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- - TensorImage tensorImage = new TensorImage(tensorImageDataType);
- -
- - tensorImage.load(tensorBuffer, colorSpaceType);
- - Bitmap bitmap = tensorImage.getBitmap();
- -
- - Bitmap expectedBitmap = createBitmap(colorSpaceType);
- - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- - }
- -
- - private static TensorBuffer createTensorBuffer(
- - DataType dataType, boolean isNormalized, ColorSpaceType colorSpaceType, float delta) {
- - switch (colorSpaceType) {
- - case RGB:
- - return createRgbTensorBuffer(dataType, isNormalized, delta);
- - case GRAYSCALE:
- - return createGrayscaleTensorBuffer(dataType, isNormalized, delta);
- - default:
- - break;
- + tensorImage.load(tensorBuffer, colorSpaceType);
- + Bitmap bitmap = tensorImage.getBitmap();
- +
- + Bitmap expectedBitmap = createBitmap(colorSpaceType);
- + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- }
- - throw new IllegalArgumentException(
- - "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
- - }
- -
- - private static Bitmap createBitmap(ColorSpaceType colorSpaceType) {
- - switch (colorSpaceType) {
- - case RGB:
- - return createRgbBitmap();
- - case GRAYSCALE:
- - return createGrayscaleBitmap();
- - default:
- - break;
- +
- + @Test
- + public void loadAndGetBitmapSucceedsWithUINT8BufferFloatImage() {
- + DataType tensorBufferDataType = UINT8;
- + DataType tensorImageDataType = FLOAT32;
- + boolean isNormalized = true;
- + ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
- +
- + TensorBuffer tensorBuffer =
- + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- + TensorImage tensorImage = new TensorImage(tensorImageDataType);
- +
- + tensorImage.load(tensorBuffer, colorSpaceType);
- + Bitmap bitmap = tensorImage.getBitmap();
- +
- + Bitmap expectedBitmap = createBitmap(colorSpaceType);
- + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- + }
- +
- + @Test
- + public void loadAndGetBitmapSucceedsWithUINT8BufferUINT8Image() {
- + DataType tensorBufferDataType = UINT8;
- + DataType tensorImageDataType = UINT8;
- + boolean isNormalized = false;
- + ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
- +
- + TensorBuffer tensorBuffer =
- + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- + TensorImage tensorImage = new TensorImage(tensorImageDataType);
- +
- + tensorImage.load(tensorBuffer, colorSpaceType);
- + Bitmap bitmap = tensorImage.getBitmap();
- +
- + Bitmap expectedBitmap = createBitmap(colorSpaceType);
- + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- + }
- +
- + private static TensorBuffer createTensorBuffer(
- + DataType dataType, boolean isNormalized, ColorSpaceType colorSpaceType, float delta) {
- + switch (colorSpaceType) {
- + case RGB:
- + return createRgbTensorBuffer(dataType, isNormalized, delta);
- + case GRAYSCALE:
- + return createGrayscaleTensorBuffer(dataType, isNormalized, delta);
- + default:
- + break;
- + }
- + throw new IllegalArgumentException(
- + "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
- + }
- +
- + private static Bitmap createBitmap(ColorSpaceType colorSpaceType) {
- + switch (colorSpaceType) {
- + case RGB:
- + return createRgbBitmap();
- + case GRAYSCALE:
- + return createGrayscaleBitmap();
- + default:
- + break;
- + }
- + throw new IllegalArgumentException(
- + "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
- }
- - throw new IllegalArgumentException(
- - "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java
- index f27edef4de779..b3130f4f2073c 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java
- @@ -16,6 +16,7 @@ limitations under the License.
- package org.tensorflow.lite.support.image;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertArrayEquals;
- import static org.junit.Assert.assertThrows;
- import static org.mockito.Mockito.when;
- @@ -31,9 +32,7 @@ import android.graphics.Bitmap.Config;
- import android.graphics.Color;
- import android.graphics.ImageFormat;
- import android.media.Image;
- -import java.nio.ByteBuffer;
- -import java.util.Arrays;
- -import java.util.Collection;
- +
- import org.junit.Before;
- import org.junit.Test;
- import org.junit.runner.RunWith;
- @@ -48,713 +47,689 @@ import org.robolectric.RobolectricTestRunner;
- import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- +import java.nio.ByteBuffer;
- +import java.util.Arrays;
- +import java.util.Collection;
- +
- /** Tests of {@link org.tensorflow.lite.support.image.TensorImage}. */
- @RunWith(Suite.class)
- -@SuiteClasses({
- - TensorImageTest.General.class,
- - TensorImageTest.LoadTensorBufferWithRgbAndGrayscale.class,
- - TensorImageTest.LoadTensorBufferWithInvalidShapeTest.class,
- - TensorImageTest.LoadTensorBufferWithYUV.class,
- - TensorImageTest.LoadTensorBufferWithImageProperties.class
- -})
- +@SuiteClasses(
- + {TensorImageTest.General.class, TensorImageTest.LoadTensorBufferWithRgbAndGrayscale.class,
- + TensorImageTest.LoadTensorBufferWithInvalidShapeTest.class,
- + TensorImageTest.LoadTensorBufferWithYUV.class,
- + TensorImageTest.LoadTensorBufferWithImageProperties.class})
- public class TensorImageTest {
- -
- - @RunWith(RobolectricTestRunner.class)
- - public static final class General extends TensorImageTest {
- -
- - private static final Bitmap exampleBitmap = createExampleBitmap();
- - private static final float[] exampleFloatPixels = createExampleFloatPixels();
- - private static final int[] exampleUint8Pixels = createExampleUint8Pixels();
- -
- - private static final int EXAMPLE_WIDTH = 5;
- - private static final int EXAMPLE_HEIGHT = 10;
- - private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH;
- - private static final int EXAMPLE_NUM_CHANNELS = 3;
- - private static final int[] EXAMPLE_SHAPE = {
- - EXAMPLE_HEIGHT, EXAMPLE_WIDTH, EXAMPLE_NUM_CHANNELS
- - };
- - private static final float MEAN = 127.5f;
- - private static final float STDDEV = 127.5f;
- -
- - @Mock Image imageMock;
- -
- - @Before
- - public void setUp() {
- - MockitoAnnotations.initMocks(this);
- - }
- -
- - @Test
- - public void defaultConstructorCreatesUint8TensorImage() {
- - TensorImage image = new TensorImage();
- - assertThat(image.getDataType()).isEqualTo(UINT8);
- - }
- -
- - @Test
- - public void createFromSucceedsWithUint8TensorImage() {
- - TensorImage uint8Image = new TensorImage(UINT8);
- - uint8Image.load(new int[] {1, 2, 3, 4, -5, 600}, new int[] {2, 1, 3});
- -
- - TensorImage floatImage = TensorImage.createFrom(uint8Image, FLOAT32);
- - float[] pixels = floatImage.getTensorBuffer().getFloatArray();
- - assertThat(pixels).isEqualTo(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 0.0f, 255.0f});
- - }
- -
- - @Test
- - public void createFromSucceedsWithFloatTensorImage() {
- - TensorImage floatImage = new TensorImage(FLOAT32);
- - floatImage.load(new float[] {1, 2.495f, 3.5f, 4.5f, -5, 600}, new int[] {2, 1, 3});
- -
- - TensorImage uint8Image = TensorImage.createFrom(floatImage, UINT8);
- - int[] pixels = uint8Image.getTensorBuffer().getIntArray();
- - assertThat(pixels).isEqualTo(new int[] {1, 2, 3, 4, 0, 255});
- - }
- -
- - @Test
- - public void loadBitmapSucceedsWithUint8TensorImage() {
- - Bitmap rgbBitmap = createRgbBitmap();
- - TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(UINT8, false, 0.0f);
- - TensorImage uint8Image = new TensorImage(UINT8);
- -
- - uint8Image.load(rgbBitmap);
- - assertThat(uint8Image.getBitmap().sameAs(rgbBitmap)).isTrue();
- - assertEqualTensorBuffers(uint8Image.getTensorBuffer(), rgbTensorBuffer);
- - assertThat(uint8Image.getDataType()).isEqualTo(UINT8);
- - }
- -
- - @Test
- - public void loadBitmapSucceedsWithFloatTensorImage() {
- - Bitmap rgbBitmap = createRgbBitmap();
- - TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(FLOAT32, false, 0.0f);
- - TensorImage floatImage = new TensorImage(FLOAT32);
- -
- - floatImage.load(rgbBitmap);
- - assertThat(floatImage.getBitmap().sameAs(rgbBitmap)).isTrue();
- - assertEqualTensorBuffers(floatImage.getTensorBuffer(), rgbTensorBuffer);
- - assertThat(floatImage.getDataType()).isEqualTo(FLOAT32);
- - }
- -
- - @Test
- - public void loadFloatArrayWithUint8TensorImage() {
- - TensorImage uint8Image = new TensorImage(UINT8);
- -
- - uint8Image.load(exampleFloatPixels, EXAMPLE_SHAPE);
- - assertThat(uint8Image.getBitmap()).isNotNull();
- - assertThat(uint8Image.getTensorBuffer().getFloatArray())
- - .isEqualTo(
- - new float
- - [exampleFloatPixels
- - .length]); // All zero because of normalization and casting when loading.
- - }
- -
- - @Test
- - public void loadFloatArrayWithFloatTensorImage() {
- - TensorImage floatImage = new TensorImage(FLOAT32);
- -
- - floatImage.load(exampleFloatPixels, EXAMPLE_SHAPE);
- - assertThat(floatImage.getTensorBuffer().getFloatArray()).isEqualTo(exampleFloatPixels);
- - }
- -
- - @Test
- - public void loadUint8ArrayWithUint8TensorImage() {
- - TensorImage uint8Image = new TensorImage(UINT8);
- -
- - uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE);
- - assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
- - assertThat(uint8Image.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
- - }
- -
- - @Test
- - public void loadUint8ArrayWithFloatTensorImage() {
- - TensorImage floatImage = new TensorImage(FLOAT32);
- -
- - floatImage.load(exampleUint8Pixels, EXAMPLE_SHAPE);
- - assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
- - }
- -
- - @Test
- - public void loadTensorBufferWithUint8TensorImage() {
- - TensorImage uint8Image = new TensorImage(UINT8);
- -
- - uint8Image.load(exampleBitmap);
- - TensorBuffer buffer = uint8Image.getTensorBuffer();
- - uint8Image.load(buffer);
- - assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
- - }
- -
- - @Test
- - public void loadTensorBufferWithFloatTensorImage() {
- - TensorImage floatImage = new TensorImage(FLOAT32);
- -
- - floatImage.load(exampleBitmap);
- - TensorBuffer buffer = floatImage.getTensorBuffer();
- - floatImage.load(buffer);
- - assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
- - }
- -
- - @Test
- - public void loadAndGetMediaImageSucceedsWithYuv420888Format() {
- - setUpImageMock(imageMock, ImageFormat.YUV_420_888);
- - TensorImage tensorImage = new TensorImage(UINT8);
- -
- - tensorImage.load(imageMock);
- - Image imageReturned = tensorImage.getMediaImage();
- -
- - assertThat(imageReturned).isEqualTo(imageMock);
- - }
- -
- - @Test
- - public void loadMediaImageFailsWithNonYuv420888Format() {
- - setUpImageMock(imageMock, ImageFormat.YUV_422_888);
- - TensorImage tensorImage = new TensorImage(UINT8);
- -
- - IllegalArgumentException exception =
- - assertThrows(IllegalArgumentException.class, () -> tensorImage.load(imageMock));
- - assertThat(exception).hasMessageThat().contains("Only supports loading YUV_420_888 Image.");
- - }
- -
- - @Test
- - public void getBitmapWithUint8TensorImage() {
- - TensorImage uint8Image = new TensorImage(UINT8);
- -
- - uint8Image.load(exampleBitmap);
- - assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
- - // Also check zero copy is effective here (input and output are references of the same
- - // object).
- - assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap);
- - // Also check we don't create new Bitmap only with reading operations.
- - assertThat(uint8Image.getBuffer().limit())
- - .isEqualTo(EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS);
- - assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap);
- -
- - uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE);
- - assertThat(uint8Image.getBitmap()).isNotSameInstanceAs(exampleBitmap);
- - }
- -
- - @Test
- - public void getBitmapWithFloatTensorImage() {
- - TensorImage floatImage = new TensorImage(FLOAT32);
- -
- - floatImage.load(exampleBitmap);
- - assertThat(floatImage.getBitmap()).isSameInstanceAs(exampleBitmap);
- - }
- -
- - @Test
- - public void getBitmapWithEmptyTensorImage() {
- - TensorImage uint8Image = new TensorImage(UINT8);
- -
- - assertThrows(IllegalStateException.class, uint8Image::getBitmap);
- - }
- -
- - @Test
- - public void getMediaImageFailsWithBackedBitmap() {
- - TensorImage tensorImage = TensorImage.fromBitmap(exampleBitmap);
- -
- - UnsupportedOperationException exception =
- - assertThrows(UnsupportedOperationException.class, () -> tensorImage.getMediaImage());
- - assertThat(exception)
- - .hasMessageThat()
- - .contains("Converting from Bitmap to android.media.Image is unsupported.");
- - }
- -
- - @Test
- - public void getMediaImageFailsWithBackedTensorBuffer() {
- - TensorImage tensorImage = new TensorImage(UINT8);
- - tensorImage.load(exampleFloatPixels, EXAMPLE_SHAPE);
- -
- - UnsupportedOperationException exception =
- - assertThrows(UnsupportedOperationException.class, () -> tensorImage.getMediaImage());
- - assertThat(exception)
- - .hasMessageThat()
- - .contains("Converting from TensorBuffer to android.media.Image is unsupported.");
- - }
- -
- - @Test
- - public void getShapeOfInternalBitmapShouldSuccess() {
- - Bitmap bitmap = Bitmap.createBitmap(300, 400, Config.ARGB_8888);
- - TensorImage image = TensorImage.fromBitmap(bitmap);
- -
- - int width = image.getWidth();
- - int height = image.getHeight();
- -
- - assertThat(width).isEqualTo(300);
- - assertThat(height).isEqualTo(400);
- - }
- -
- - @Test
- - public void getShapeOfInternalTensorBufferShouldSuccess() {
- - TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 400, 300, 3}, UINT8);
- - TensorImage image = new TensorImage();
- - image.load(buffer);
- -
- - int width = image.getWidth();
- - int height = image.getHeight();
- -
- - assertThat(width).isEqualTo(300);
- - assertThat(height).isEqualTo(400);
- - }
- -
- - @Test
- - public void getShapeOfNullImageShouldThrow() {
- - TensorImage image = new TensorImage();
- -
- - assertThrows(IllegalStateException.class, image::getHeight);
- - }
- -
- - @Test
- - public void getShapeOfACorruptedBufferShouldThrowRatherThanCrash() {
- - int[] data = new int[] {1, 2, 3, 4, 5, 6};
- - TensorBuffer buffer = TensorBuffer.createDynamic(UINT8);
- - buffer.loadArray(data, new int[] {1, 1, 2, 3});
- - TensorImage image = new TensorImage();
- - image.load(buffer);
- - // Reload data but with an invalid shape, which leads to `buffer` corrupted.
- - int[] newData = new int[] {1, 2, 3};
- - buffer.loadArray(newData, new int[] {1, 1, 1, 3});
- -
- - assertThrows(IllegalArgumentException.class, image::getHeight);
- - }
- -
- - @Test
- - public void getColorSpaceTypeSucceedsWithBitmapARGB_8888() {
- - Bitmap rgbBitmap = createRgbBitmap();
- - TensorImage tensorImage = TensorImage.fromBitmap(rgbBitmap);
- -
- - assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
- - }
- -
- - @Test
- - public void getColorSpaceTypeSucceedsWithRgbTensorBuffer() {
- - TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false);
- - TensorImage tensorImage = new TensorImage();
- - tensorImage.load(rgbBuffer);
- -
- - assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
- - }
- -
- - @Test
- - public void getColorSpaceTypeSucceedsWithGrayscaleTensorBuffer() {
- - TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false);
- - TensorImage tensorImage = new TensorImage();
- - tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE);
- -
- - assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
- - }
- -
- - @Test
- - public void getColorSpaceTypeSucceedsWithRepeatedLoading() {
- - TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false);
- - TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false);
- - Bitmap rgbBitmap = createRgbBitmap();
- - TensorImage tensorImage = new TensorImage();
- -
- - tensorImage.load(rgbBuffer);
- - assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
- - tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE);
- - assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
- - tensorImage.load(rgbBitmap);
- - assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
- - }
- -
- - @Test
- - public void getColorSpaceTypeFailsWhenNoImageHasBeenLoaded() {
- - TensorImage tensorImage = new TensorImage();
- -
- - IllegalStateException exception =
- - assertThrows(IllegalStateException.class, tensorImage::getColorSpaceType);
- - assertThat(exception).hasMessageThat().contains("No image has been loaded yet.");
- - }
- -
- - /**
- - * Creates an example bit map, which is a 10x10 ARGB bitmap and pixels are set by: pixel[i] =
- - * {A: 0, B: i + 2, G: i + 1, G: i}, where i is the flatten index
- - */
- - private static Bitmap createExampleBitmap() {
- - int[] colors = new int[EXAMPLE_NUM_PIXELS];
- - for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) {
- - colors[i] = Color.rgb(i, i + 1, i + 2);
- - }
- -
- - return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- - }
- -
- - private static float[] createExampleFloatPixels() {
- - float[] pixels = new float[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS];
- - for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) {
- - pixels[j++] = (i - MEAN) / STDDEV;
- - pixels[j++] = (i + 1 - MEAN) / STDDEV;
- - pixels[j++] = (i + 2 - MEAN) / STDDEV;
- - }
- - return pixels;
- - }
- -
- - private static int[] createExampleUint8Pixels() {
- - int[] pixels = new int[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS];
- - for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) {
- - pixels[j++] = i;
- - pixels[j++] = i + 1;
- - pixels[j++] = i + 2;
- - }
- - return pixels;
- - }
- - }
- -
- - /** Parameterized tests for loading TensorBuffers with RGB and Grayscale images. */
- - @RunWith(ParameterizedRobolectricTestRunner.class)
- - public static final class LoadTensorBufferWithRgbAndGrayscale extends TensorImageTest {
- -
- - /**
- - * Difference between the pair of float and uint8 values. It is used to test the data
- - * conversion.
- - */
- - private static final float DELTA = 0.1f;
- -
- - /** The data type that used to create the TensorBuffer. */
- - @Parameter(0)
- - public DataType tensorBufferDataType;
- -
- - /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */
- - @Parameter(1)
- - public boolean isNormalized;
- -
- - /** The color space type of the TensorBuffer. */
- - @Parameter(2)
- - public ColorSpaceType colorSpaceType;
- -
- - /** The data type that used to create the TensorImage. */
- - @Parameter(3)
- - public DataType tensorImageDataType;
- -
- - @Parameters(
- - name =
- - "tensorBufferDataType={0}; isNormalized={1}; colorSpaceType={2};"
- - + " tensorImageDataType={3}")
- - public static Collection<Object[]> data() {
- - return Arrays.asList(
- - new Object[][] {
- - {FLOAT32, true, ColorSpaceType.RGB, FLOAT32},
- - {FLOAT32, false, ColorSpaceType.RGB, UINT8},
- - {UINT8, true, ColorSpaceType.RGB, FLOAT32},
- - {UINT8, false, ColorSpaceType.RGB, UINT8},
- - });
- - }
- -
- - @Test
- - public void loadAndGetBitmapSucceedsWithTensorBufferAndColorSpaceType() {
- - TensorBuffer tensorBuffer =
- - createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- - TensorImage tensorImage = new TensorImage(tensorImageDataType);
- -
- - tensorImage.load(tensorBuffer, colorSpaceType);
- - Bitmap bitmap = tensorImage.getBitmap();
- -
- - Bitmap expectedBitmap = createBitmap(colorSpaceType);
- - assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- - }
- -
- - @Test
- - public void loadAndGetTensorBufferSucceedsWithTensorBufferAndColorSpaceType() {
- - TensorBuffer tensorBuffer =
- - createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- - TensorImage tensorImage = new TensorImage(tensorImageDataType);
- -
- - tensorImage.load(tensorBuffer, colorSpaceType);
- - TensorBuffer buffer = tensorImage.getTensorBuffer();
- -
- - // If tensorBufferDataType is UINT8, expectedTensorBuffer should not contain delta.
- - float expectedResidual = tensorBufferDataType == UINT8 ? 0.f : DELTA;
- - TensorBuffer expectedTensorBuffer =
- - createTensorBuffer(tensorImageDataType, isNormalized, colorSpaceType, expectedResidual);
- - assertEqualTensorBuffers(buffer, expectedTensorBuffer);
- - }
- -
- - private static TensorBuffer createTensorBuffer(
- - DataType dataType, boolean isNormalized, ColorSpaceType colorSpaceType, float delta) {
- - switch (colorSpaceType) {
- - case RGB:
- - return createRgbTensorBuffer(dataType, isNormalized, delta);
- - case GRAYSCALE:
- - return createGrayscaleTensorBuffer(dataType, isNormalized, delta);
- - default:
- - break;
- - }
- - throw new IllegalArgumentException(
- - "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
- - }
- -
- - private static Bitmap createBitmap(ColorSpaceType colorSpaceType) {
- - switch (colorSpaceType) {
- - case RGB:
- - return createRgbBitmap();
- - case GRAYSCALE:
- - return createGrayscaleBitmap();
- - default:
- - break;
- - }
- - throw new IllegalArgumentException(
- - "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
- - }
- - }
- -
- - /** Parameterized tests for loading TensorBuffers with YUV images. */
- - @RunWith(ParameterizedRobolectricTestRunner.class)
- - public static final class LoadTensorBufferWithYUV extends TensorImageTest {
- -
- - private static final int HEIGHT = 2;
- - private static final int WIDTH = 3;
- -
- - @Parameter(0)
- - public ColorSpaceType colorSpaceType;
- -
- - @Parameters(name = "colorSpaceType={0}")
- - public static Collection<Object[]> data() {
- - return Arrays.asList(
- - new Object[][] {
- - {ColorSpaceType.NV12},
- - {ColorSpaceType.NV21},
- - {ColorSpaceType.YV12},
- - {ColorSpaceType.YV21},
- - });
- - }
- -
- - @Test
- - public void loadTensorBufferWithColorSpaceShouldFail() {
- - TensorImage tensorImage = new TensorImage();
- -
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class,
- - () -> tensorImage.load(TensorBuffer.createDynamic(DataType.FLOAT32), colorSpaceType));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
- - + " `load(TensorBuffer, ImageProperties)` for other color space types.");
- - }
- -
- - @Test
- - public void loadTensorBufferAndGetBitmapShouldFail() {
- - int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- - tensorBuffer.loadArray(data, new int[] {data.length});
- -
- - ImageProperties imageProperties =
- - ImageProperties.builder()
- - .setHeight(HEIGHT)
- - .setWidth(WIDTH)
- - .setColorSpaceType(colorSpaceType)
- - .build();
- -
- - TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
- - tensorImage.load(tensorBuffer, imageProperties);
- -
- - UnsupportedOperationException exception =
- - assertThrows(UnsupportedOperationException.class, () -> tensorImage.getBitmap());
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "convertTensorBufferToBitmap() is unsupported for the color space type "
- - + colorSpaceType.name());
- - }
- - }
- -
- - /** Parameterized tests for loading TensorBuffers with ImageProperties. */
- - @RunWith(ParameterizedRobolectricTestRunner.class)
- - public static final class LoadTensorBufferWithImageProperties extends TensorImageTest {
- -
- - private static final int HEIGHT = 2;
- - private static final int WIDTH = 3;
- - private static final int WRONG_WIDTH = 10;
- -
- - @Parameter(0)
- - public ColorSpaceType colorSpaceType;
- -
- - @Parameters(name = "colorSpaceType={0}")
- - public static Collection<Object[]> data() {
- - return Arrays.asList(
- - new Object[][] {
- - {ColorSpaceType.RGB},
- - {ColorSpaceType.GRAYSCALE},
- - {ColorSpaceType.NV12},
- - {ColorSpaceType.NV21},
- - {ColorSpaceType.YV12},
- - {ColorSpaceType.YV21},
- - });
- - }
- -
- - @Test
- - public void loadAndGetTensorBufferShouldSucceedWithCorrectProperties() {
- - int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- - tensorBuffer.loadArray(data, new int[] {data.length});
- -
- - ImageProperties imageProperties =
- - ImageProperties.builder()
- - .setHeight(HEIGHT)
- - .setWidth(WIDTH)
- - .setColorSpaceType(colorSpaceType)
- - .build();
- -
- - TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
- - tensorImage.load(tensorBuffer, imageProperties);
- -
- - assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer);
- - }
- -
- - @Test
- - public void loadAndGetTensorBufferShouldSucceedWithLargerBuffer() {
- - // Should allow buffer to be greater than the size specified by height and width.
- - int moreElements = 1;
- - int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH) + moreElements];
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- - tensorBuffer.loadArray(data, new int[] {data.length});
- -
- - ImageProperties imageProperties =
- - ImageProperties.builder()
- - .setHeight(HEIGHT)
- - .setWidth(WIDTH)
- - .setColorSpaceType(colorSpaceType)
- - .build();
- -
- - TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
- - tensorImage.load(tensorBuffer, imageProperties);
- -
- - assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer);
- - }
- -
- - @Test
- - public void loadAndGetByteBufferShouldSucceedWithCorrectProperties() {
- - ByteBuffer byteBuffer = ByteBuffer.allocate(colorSpaceType.getNumElements(HEIGHT, WIDTH));
- -
- - ImageProperties imageProperties =
- - ImageProperties.builder()
- - .setHeight(HEIGHT)
- - .setWidth(WIDTH)
- - .setColorSpaceType(colorSpaceType)
- - .build();
- -
- - TensorImage tensorImage = new TensorImage(DataType.UINT8);
- - tensorImage.load(byteBuffer, imageProperties);
- -
- - assertEqualByteBuffers(tensorImage.getBuffer(), byteBuffer);
- - }
- -
- - @Test
- - public void loadTensorBufferWithShouldFailWithWrongImageShape() {
- - int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- - tensorBuffer.loadArray(data, new int[] {data.length});
- -
- - ImageProperties imageProperties =
- - ImageProperties.builder()
- - .setHeight(HEIGHT)
- - .setWidth(WRONG_WIDTH)
- - .setColorSpaceType(colorSpaceType)
- - .build();
- -
- - TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
- -
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class,
- - () -> tensorImage.load(tensorBuffer, imageProperties));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - String.format(
- - "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
- - + " expected number of elements should be at least %d.",
- - data.length,
- - colorSpaceType.name(),
- - HEIGHT,
- - WRONG_WIDTH,
- - colorSpaceType.getNumElements(HEIGHT, WRONG_WIDTH)));
- - }
- -
- - @Test
- - public void getShapeOfInternalTensorBufferShouldSuccess() {
- - int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- - tensorBuffer.loadArray(data, new int[] {data.length});
- -
- - ImageProperties imageProperties =
- - ImageProperties.builder()
- - .setHeight(HEIGHT)
- - .setWidth(WIDTH)
- - .setColorSpaceType(colorSpaceType)
- - .build();
- -
- - TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
- - tensorImage.load(tensorBuffer, imageProperties);
- -
- - assertThat(tensorImage.getWidth()).isEqualTo(WIDTH);
- - assertThat(tensorImage.getHeight()).isEqualTo(HEIGHT);
- - }
- - }
- -
- - /** Parameterized tests for loading TensorBuffer with invalid shapes. */
- - @RunWith(ParameterizedRobolectricTestRunner.class)
- - public static final class LoadTensorBufferWithInvalidShapeTest extends TensorImageTest {
- -
- - private static final String RGB_ASSERT_SHAPE_MESSAGE =
- - "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
- - + " representing R, G, B in order. The provided image shape is ";
- - private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE =
- - "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
- - + " shape is ";
- -
- - @Parameter(0)
- - public ColorSpaceType colorSpaceType;
- -
- - /** The shape that does not match the colorSpaceType. */
- - @Parameter(1)
- - public int[] invalidShape;
- -
- - @Parameter(2)
- - public String errorMessage;
- -
- - @Parameters(name = "colorSpaceType={0}; invalidShape={1}")
- - public static Collection<Object[]> data() {
- - return Arrays.asList(
- - new Object[][] {
- - {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- - {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- - });
- - }
- -
- - @Test
- - public void loadTensorBufferWithInvalidShape() {
- - TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(invalidShape, UINT8);
- - TensorImage tensorImage = new TensorImage();
- -
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> tensorImage.load(tensorBuffer, colorSpaceType));
- - assertThat(exception).hasMessageThat().contains(errorMessage + Arrays.toString(invalidShape));
- + @RunWith(RobolectricTestRunner.class)
- + public static final class General extends TensorImageTest {
- + private static final Bitmap exampleBitmap = createExampleBitmap();
- + private static final float[] exampleFloatPixels = createExampleFloatPixels();
- + private static final int[] exampleUint8Pixels = createExampleUint8Pixels();
- +
- + private static final int EXAMPLE_WIDTH = 5;
- + private static final int EXAMPLE_HEIGHT = 10;
- + private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH;
- + private static final int EXAMPLE_NUM_CHANNELS = 3;
- + private static final int[] EXAMPLE_SHAPE = {
- + EXAMPLE_HEIGHT, EXAMPLE_WIDTH, EXAMPLE_NUM_CHANNELS};
- + private static final float MEAN = 127.5f;
- + private static final float STDDEV = 127.5f;
- +
- + @Mock
- + Image imageMock;
- +
- + @Before
- + public void setUp() {
- + MockitoAnnotations.initMocks(this);
- + }
- +
- + @Test
- + public void defaultConstructorCreatesUint8TensorImage() {
- + TensorImage image = new TensorImage();
- + assertThat(image.getDataType()).isEqualTo(UINT8);
- + }
- +
- + @Test
- + public void createFromSucceedsWithUint8TensorImage() {
- + TensorImage uint8Image = new TensorImage(UINT8);
- + uint8Image.load(new int[] {1, 2, 3, 4, -5, 600}, new int[] {2, 1, 3});
- +
- + TensorImage floatImage = TensorImage.createFrom(uint8Image, FLOAT32);
- + float[] pixels = floatImage.getTensorBuffer().getFloatArray();
- + assertThat(pixels).isEqualTo(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 0.0f, 255.0f});
- + }
- +
- + @Test
- + public void createFromSucceedsWithFloatTensorImage() {
- + TensorImage floatImage = new TensorImage(FLOAT32);
- + floatImage.load(new float[] {1, 2.495f, 3.5f, 4.5f, -5, 600}, new int[] {2, 1, 3});
- +
- + TensorImage uint8Image = TensorImage.createFrom(floatImage, UINT8);
- + int[] pixels = uint8Image.getTensorBuffer().getIntArray();
- + assertThat(pixels).isEqualTo(new int[] {1, 2, 3, 4, 0, 255});
- + }
- +
- + @Test
- + public void loadBitmapSucceedsWithUint8TensorImage() {
- + Bitmap rgbBitmap = createRgbBitmap();
- + TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(UINT8, false, 0.0f);
- + TensorImage uint8Image = new TensorImage(UINT8);
- +
- + uint8Image.load(rgbBitmap);
- + assertThat(uint8Image.getBitmap().sameAs(rgbBitmap)).isTrue();
- + assertEqualTensorBuffers(uint8Image.getTensorBuffer(), rgbTensorBuffer);
- + assertThat(uint8Image.getDataType()).isEqualTo(UINT8);
- + }
- +
- + @Test
- + public void loadBitmapSucceedsWithFloatTensorImage() {
- + Bitmap rgbBitmap = createRgbBitmap();
- + TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(FLOAT32, false, 0.0f);
- + TensorImage floatImage = new TensorImage(FLOAT32);
- +
- + floatImage.load(rgbBitmap);
- + assertThat(floatImage.getBitmap().sameAs(rgbBitmap)).isTrue();
- + assertEqualTensorBuffers(floatImage.getTensorBuffer(), rgbTensorBuffer);
- + assertThat(floatImage.getDataType()).isEqualTo(FLOAT32);
- + }
- +
- + @Test
- + public void loadFloatArrayWithUint8TensorImage() {
- + TensorImage uint8Image = new TensorImage(UINT8);
- +
- + uint8Image.load(exampleFloatPixels, EXAMPLE_SHAPE);
- + assertThat(uint8Image.getBitmap()).isNotNull();
- + assertThat(uint8Image.getTensorBuffer().getFloatArray())
- + .isEqualTo(new float[exampleFloatPixels.length]); // All zero because of
- + // normalization and casting
- + // when loading.
- + }
- +
- + @Test
- + public void loadFloatArrayWithFloatTensorImage() {
- + TensorImage floatImage = new TensorImage(FLOAT32);
- +
- + floatImage.load(exampleFloatPixels, EXAMPLE_SHAPE);
- + assertThat(floatImage.getTensorBuffer().getFloatArray()).isEqualTo(exampleFloatPixels);
- + }
- +
- + @Test
- + public void loadUint8ArrayWithUint8TensorImage() {
- + TensorImage uint8Image = new TensorImage(UINT8);
- +
- + uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE);
- + assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
- + assertThat(uint8Image.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
- + }
- +
- + @Test
- + public void loadUint8ArrayWithFloatTensorImage() {
- + TensorImage floatImage = new TensorImage(FLOAT32);
- +
- + floatImage.load(exampleUint8Pixels, EXAMPLE_SHAPE);
- + assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
- + }
- +
- + @Test
- + public void loadTensorBufferWithUint8TensorImage() {
- + TensorImage uint8Image = new TensorImage(UINT8);
- +
- + uint8Image.load(exampleBitmap);
- + TensorBuffer buffer = uint8Image.getTensorBuffer();
- + uint8Image.load(buffer);
- + assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
- + }
- +
- + @Test
- + public void loadTensorBufferWithFloatTensorImage() {
- + TensorImage floatImage = new TensorImage(FLOAT32);
- +
- + floatImage.load(exampleBitmap);
- + TensorBuffer buffer = floatImage.getTensorBuffer();
- + floatImage.load(buffer);
- + assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
- + }
- +
- + @Test
- + public void loadAndGetMediaImageSucceedsWithYuv420888Format() {
- + setUpImageMock(imageMock, ImageFormat.YUV_420_888);
- + TensorImage tensorImage = new TensorImage(UINT8);
- +
- + tensorImage.load(imageMock);
- + Image imageReturned = tensorImage.getMediaImage();
- +
- + assertThat(imageReturned).isEqualTo(imageMock);
- + }
- +
- + @Test
- + public void loadMediaImageFailsWithNonYuv420888Format() {
- + setUpImageMock(imageMock, ImageFormat.YUV_422_888);
- + TensorImage tensorImage = new TensorImage(UINT8);
- +
- + IllegalArgumentException exception =
- + assertThrows(IllegalArgumentException.class, () -> tensorImage.load(imageMock));
- + assertThat(exception).hasMessageThat().contains(
- + "Only supports loading YUV_420_888 Image.");
- + }
- +
- + @Test
- + public void getBitmapWithUint8TensorImage() {
- + TensorImage uint8Image = new TensorImage(UINT8);
- +
- + uint8Image.load(exampleBitmap);
- + assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
- + // Also check zero copy is effective here (input and output are references of the same
- + // object).
- + assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap);
- + // Also check we don't create new Bitmap only with reading operations.
- + assertThat(uint8Image.getBuffer().limit())
- + .isEqualTo(EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS);
- + assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap);
- +
- + uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE);
- + assertThat(uint8Image.getBitmap()).isNotSameInstanceAs(exampleBitmap);
- + }
- +
- + @Test
- + public void getBitmapWithFloatTensorImage() {
- + TensorImage floatImage = new TensorImage(FLOAT32);
- +
- + floatImage.load(exampleBitmap);
- + assertThat(floatImage.getBitmap()).isSameInstanceAs(exampleBitmap);
- + }
- +
- + @Test
- + public void getBitmapWithEmptyTensorImage() {
- + TensorImage uint8Image = new TensorImage(UINT8);
- +
- + assertThrows(IllegalStateException.class, uint8Image::getBitmap);
- + }
- +
- + @Test
- + public void getMediaImageFailsWithBackedBitmap() {
- + TensorImage tensorImage = TensorImage.fromBitmap(exampleBitmap);
- +
- + UnsupportedOperationException exception = assertThrows(
- + UnsupportedOperationException.class, () -> tensorImage.getMediaImage());
- + assertThat(exception).hasMessageThat().contains(
- + "Converting from Bitmap to android.media.Image is unsupported.");
- + }
- +
- + @Test
- + public void getMediaImageFailsWithBackedTensorBuffer() {
- + TensorImage tensorImage = new TensorImage(UINT8);
- + tensorImage.load(exampleFloatPixels, EXAMPLE_SHAPE);
- +
- + UnsupportedOperationException exception = assertThrows(
- + UnsupportedOperationException.class, () -> tensorImage.getMediaImage());
- + assertThat(exception).hasMessageThat().contains(
- + "Converting from TensorBuffer to android.media.Image is unsupported.");
- + }
- +
- + @Test
- + public void getShapeOfInternalBitmapShouldSuccess() {
- + Bitmap bitmap = Bitmap.createBitmap(300, 400, Config.ARGB_8888);
- + TensorImage image = TensorImage.fromBitmap(bitmap);
- +
- + int width = image.getWidth();
- + int height = image.getHeight();
- +
- + assertThat(width).isEqualTo(300);
- + assertThat(height).isEqualTo(400);
- + }
- +
- + @Test
- + public void getShapeOfInternalTensorBufferShouldSuccess() {
- + TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 400, 300, 3}, UINT8);
- + TensorImage image = new TensorImage();
- + image.load(buffer);
- +
- + int width = image.getWidth();
- + int height = image.getHeight();
- +
- + assertThat(width).isEqualTo(300);
- + assertThat(height).isEqualTo(400);
- + }
- +
- + @Test
- + public void getShapeOfNullImageShouldThrow() {
- + TensorImage image = new TensorImage();
- +
- + assertThrows(IllegalStateException.class, image::getHeight);
- + }
- +
- + @Test
- + public void getShapeOfACorruptedBufferShouldThrowRatherThanCrash() {
- + int[] data = new int[] {1, 2, 3, 4, 5, 6};
- + TensorBuffer buffer = TensorBuffer.createDynamic(UINT8);
- + buffer.loadArray(data, new int[] {1, 1, 2, 3});
- + TensorImage image = new TensorImage();
- + image.load(buffer);
- + // Reload data but with an invalid shape, which leads to `buffer` corrupted.
- + int[] newData = new int[] {1, 2, 3};
- + buffer.loadArray(newData, new int[] {1, 1, 1, 3});
- +
- + assertThrows(IllegalArgumentException.class, image::getHeight);
- + }
- +
- + @Test
- + public void getColorSpaceTypeSucceedsWithBitmapARGB_8888() {
- + Bitmap rgbBitmap = createRgbBitmap();
- + TensorImage tensorImage = TensorImage.fromBitmap(rgbBitmap);
- +
- + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
- + }
- +
- + @Test
- + public void getColorSpaceTypeSucceedsWithRgbTensorBuffer() {
- + TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false);
- + TensorImage tensorImage = new TensorImage();
- + tensorImage.load(rgbBuffer);
- +
- + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
- + }
- +
- + @Test
- + public void getColorSpaceTypeSucceedsWithGrayscaleTensorBuffer() {
- + TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false);
- + TensorImage tensorImage = new TensorImage();
- + tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE);
- +
- + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
- + }
- +
- + @Test
- + public void getColorSpaceTypeSucceedsWithRepeatedLoading() {
- + TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false);
- + TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false);
- + Bitmap rgbBitmap = createRgbBitmap();
- + TensorImage tensorImage = new TensorImage();
- +
- + tensorImage.load(rgbBuffer);
- + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
- + tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE);
- + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
- + tensorImage.load(rgbBitmap);
- + assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
- + }
- +
- + @Test
- + public void getColorSpaceTypeFailsWhenNoImageHasBeenLoaded() {
- + TensorImage tensorImage = new TensorImage();
- +
- + IllegalStateException exception =
- + assertThrows(IllegalStateException.class, tensorImage::getColorSpaceType);
- + assertThat(exception).hasMessageThat().contains("No image has been loaded yet.");
- + }
- +
- + /**
- + * Creates an example bit map, which is a 10x10 ARGB bitmap and pixels are set by: pixel[i]
- + * = {A: 0, B: i + 2, G: i + 1, G: i}, where i is the flatten index
- + */
- + private static Bitmap createExampleBitmap() {
- + int[] colors = new int[EXAMPLE_NUM_PIXELS];
- + for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) {
- + colors[i] = Color.rgb(i, i + 1, i + 2);
- + }
- +
- + return Bitmap.createBitmap(
- + colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- + }
- +
- + private static float[] createExampleFloatPixels() {
- + float[] pixels = new float[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS];
- + for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) {
- + pixels[j++] = (i - MEAN) / STDDEV;
- + pixels[j++] = (i + 1 - MEAN) / STDDEV;
- + pixels[j++] = (i + 2 - MEAN) / STDDEV;
- + }
- + return pixels;
- + }
- +
- + private static int[] createExampleUint8Pixels() {
- + int[] pixels = new int[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS];
- + for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) {
- + pixels[j++] = i;
- + pixels[j++] = i + 1;
- + pixels[j++] = i + 2;
- + }
- + return pixels;
- + }
- + }
- +
- + /** Parameterized tests for loading TensorBuffers with RGB and Grayscale images. */
- + @RunWith(ParameterizedRobolectricTestRunner.class)
- + public static final class LoadTensorBufferWithRgbAndGrayscale extends TensorImageTest {
- + /**
- + * Difference between the pair of float and uint8 values. It is used to test the data
- + * conversion.
- + */
- + private static final float DELTA = 0.1f;
- +
- + /** The data type that used to create the TensorBuffer. */
- + @Parameter(0)
- + public DataType tensorBufferDataType;
- +
- + /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */
- + @Parameter(1)
- + public boolean isNormalized;
- +
- + /** The color space type of the TensorBuffer. */
- + @Parameter(2)
- + public ColorSpaceType colorSpaceType;
- +
- + /** The data type that used to create the TensorImage. */
- + @Parameter(3)
- + public DataType tensorImageDataType;
- +
- + @Parameters(name = "tensorBufferDataType={0}; isNormalized={1}; colorSpaceType={2};"
- + + " tensorImageDataType={3}")
- + public static Collection<Object[]>
- + data() {
- + return Arrays.asList(new Object[][] {
- + {FLOAT32, true, ColorSpaceType.RGB, FLOAT32},
- + {FLOAT32, false, ColorSpaceType.RGB, UINT8},
- + {UINT8, true, ColorSpaceType.RGB, FLOAT32},
- + {UINT8, false, ColorSpaceType.RGB, UINT8},
- + });
- + }
- +
- + @Test
- + public void loadAndGetBitmapSucceedsWithTensorBufferAndColorSpaceType() {
- + TensorBuffer tensorBuffer =
- + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- + TensorImage tensorImage = new TensorImage(tensorImageDataType);
- +
- + tensorImage.load(tensorBuffer, colorSpaceType);
- + Bitmap bitmap = tensorImage.getBitmap();
- +
- + Bitmap expectedBitmap = createBitmap(colorSpaceType);
- + assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- + }
- +
- + @Test
- + public void loadAndGetTensorBufferSucceedsWithTensorBufferAndColorSpaceType() {
- + TensorBuffer tensorBuffer =
- + createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- + TensorImage tensorImage = new TensorImage(tensorImageDataType);
- +
- + tensorImage.load(tensorBuffer, colorSpaceType);
- + TensorBuffer buffer = tensorImage.getTensorBuffer();
- +
- + // If tensorBufferDataType is UINT8, expectedTensorBuffer should not contain delta.
- + float expectedResidual = tensorBufferDataType == UINT8 ? 0.f : DELTA;
- + TensorBuffer expectedTensorBuffer = createTensorBuffer(
- + tensorImageDataType, isNormalized, colorSpaceType, expectedResidual);
- + assertEqualTensorBuffers(buffer, expectedTensorBuffer);
- + }
- +
- + private static TensorBuffer createTensorBuffer(DataType dataType, boolean isNormalized,
- + ColorSpaceType colorSpaceType, float delta) {
- + switch (colorSpaceType) {
- + case RGB:
- + return createRgbTensorBuffer(dataType, isNormalized, delta);
- + case GRAYSCALE:
- + return createGrayscaleTensorBuffer(dataType, isNormalized, delta);
- + default:
- + break;
- + }
- + throw new IllegalArgumentException(
- + "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
- + }
- +
- + private static Bitmap createBitmap(ColorSpaceType colorSpaceType) {
- + switch (colorSpaceType) {
- + case RGB:
- + return createRgbBitmap();
- + case GRAYSCALE:
- + return createGrayscaleBitmap();
- + default:
- + break;
- + }
- + throw new IllegalArgumentException(
- + "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
- + }
- + }
- +
- + /** Parameterized tests for loading TensorBuffers with YUV images. */
- + @RunWith(ParameterizedRobolectricTestRunner.class)
- + public static final class LoadTensorBufferWithYUV extends TensorImageTest {
- + private static final int HEIGHT = 2;
- + private static final int WIDTH = 3;
- +
- + @Parameter(0)
- + public ColorSpaceType colorSpaceType;
- +
- + @Parameters(name = "colorSpaceType={0}")
- + public static Collection<Object[]> data() {
- + return Arrays.asList(new Object[][] {
- + {ColorSpaceType.NV12},
- + {ColorSpaceType.NV21},
- + {ColorSpaceType.YV12},
- + {ColorSpaceType.YV21},
- + });
- + }
- +
- + @Test
- + public void loadTensorBufferWithColorSpaceShouldFail() {
- + TensorImage tensorImage = new TensorImage();
- +
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + ()
- + -> tensorImage.load(
- + TensorBuffer.createDynamic(DataType.FLOAT32), colorSpaceType));
- + assertThat(exception).hasMessageThat().contains(
- + "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
- + + " `load(TensorBuffer, ImageProperties)` for other color space types.");
- + }
- +
- + @Test
- + public void loadTensorBufferAndGetBitmapShouldFail() {
- + int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- + tensorBuffer.loadArray(data, new int[] {data.length});
- +
- + ImageProperties imageProperties = ImageProperties.builder()
- + .setHeight(HEIGHT)
- + .setWidth(WIDTH)
- + .setColorSpaceType(colorSpaceType)
- + .build();
- +
- + TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
- + tensorImage.load(tensorBuffer, imageProperties);
- +
- + UnsupportedOperationException exception = assertThrows(
- + UnsupportedOperationException.class, () -> tensorImage.getBitmap());
- + assertThat(exception).hasMessageThat().contains(
- + "convertTensorBufferToBitmap() is unsupported for the color space type "
- + + colorSpaceType.name());
- + }
- + }
- +
- + /** Parameterized tests for loading TensorBuffers with ImageProperties. */
- + @RunWith(ParameterizedRobolectricTestRunner.class)
- + public static final class LoadTensorBufferWithImageProperties extends TensorImageTest {
- + private static final int HEIGHT = 2;
- + private static final int WIDTH = 3;
- + private static final int WRONG_WIDTH = 10;
- +
- + @Parameter(0)
- + public ColorSpaceType colorSpaceType;
- +
- + @Parameters(name = "colorSpaceType={0}")
- + public static Collection<Object[]> data() {
- + return Arrays.asList(new Object[][] {
- + {ColorSpaceType.RGB},
- + {ColorSpaceType.GRAYSCALE},
- + {ColorSpaceType.NV12},
- + {ColorSpaceType.NV21},
- + {ColorSpaceType.YV12},
- + {ColorSpaceType.YV21},
- + });
- + }
- +
- + @Test
- + public void loadAndGetTensorBufferShouldSucceedWithCorrectProperties() {
- + int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- + tensorBuffer.loadArray(data, new int[] {data.length});
- +
- + ImageProperties imageProperties = ImageProperties.builder()
- + .setHeight(HEIGHT)
- + .setWidth(WIDTH)
- + .setColorSpaceType(colorSpaceType)
- + .build();
- +
- + TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
- + tensorImage.load(tensorBuffer, imageProperties);
- +
- + assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer);
- + }
- +
- + @Test
- + public void loadAndGetTensorBufferShouldSucceedWithLargerBuffer() {
- + // Should allow buffer to be greater than the size specified by height and width.
- + int moreElements = 1;
- + int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH) + moreElements];
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- + tensorBuffer.loadArray(data, new int[] {data.length});
- +
- + ImageProperties imageProperties = ImageProperties.builder()
- + .setHeight(HEIGHT)
- + .setWidth(WIDTH)
- + .setColorSpaceType(colorSpaceType)
- + .build();
- +
- + TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
- + tensorImage.load(tensorBuffer, imageProperties);
- +
- + assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer);
- + }
- +
- + @Test
- + public void loadAndGetByteBufferShouldSucceedWithCorrectProperties() {
- + ByteBuffer byteBuffer =
- + ByteBuffer.allocate(colorSpaceType.getNumElements(HEIGHT, WIDTH));
- +
- + ImageProperties imageProperties = ImageProperties.builder()
- + .setHeight(HEIGHT)
- + .setWidth(WIDTH)
- + .setColorSpaceType(colorSpaceType)
- + .build();
- +
- + TensorImage tensorImage = new TensorImage(DataType.UINT8);
- + tensorImage.load(byteBuffer, imageProperties);
- +
- + assertEqualByteBuffers(tensorImage.getBuffer(), byteBuffer);
- + }
- +
- + @Test
- + public void loadTensorBufferWithShouldFailWithWrongImageShape() {
- + int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- + tensorBuffer.loadArray(data, new int[] {data.length});
- +
- + ImageProperties imageProperties = ImageProperties.builder()
- + .setHeight(HEIGHT)
- + .setWidth(WRONG_WIDTH)
- + .setColorSpaceType(colorSpaceType)
- + .build();
- +
- + TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
- +
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> tensorImage.load(tensorBuffer, imageProperties));
- + assertThat(exception).hasMessageThat().contains(String.format(
- + "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
- + + " expected number of elements should be at least %d.",
- + data.length, colorSpaceType.name(), HEIGHT, WRONG_WIDTH,
- + colorSpaceType.getNumElements(HEIGHT, WRONG_WIDTH)));
- + }
- +
- + @Test
- + public void getShapeOfInternalTensorBufferShouldSuccess() {
- + int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- + tensorBuffer.loadArray(data, new int[] {data.length});
- +
- + ImageProperties imageProperties = ImageProperties.builder()
- + .setHeight(HEIGHT)
- + .setWidth(WIDTH)
- + .setColorSpaceType(colorSpaceType)
- + .build();
- +
- + TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
- + tensorImage.load(tensorBuffer, imageProperties);
- +
- + assertThat(tensorImage.getWidth()).isEqualTo(WIDTH);
- + assertThat(tensorImage.getHeight()).isEqualTo(HEIGHT);
- + }
- + }
- +
- + /** Parameterized tests for loading TensorBuffer with invalid shapes. */
- + @RunWith(ParameterizedRobolectricTestRunner.class)
- + public static final class LoadTensorBufferWithInvalidShapeTest extends TensorImageTest {
- + private static final String RGB_ASSERT_SHAPE_MESSAGE =
- + "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
- + + " representing R, G, B in order. The provided image shape is ";
- + private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE =
- + "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
- + + " shape is ";
- +
- + @Parameter(0)
- + public ColorSpaceType colorSpaceType;
- +
- + /** The shape that does not match the colorSpaceType. */
- + @Parameter(1)
- + public int[] invalidShape;
- +
- + @Parameter(2)
- + public String errorMessage;
- +
- + @Parameters(name = "colorSpaceType={0}; invalidShape={1}")
- + public static Collection<Object[]> data() {
- + return Arrays.asList(new Object[][] {
- + {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20},
- + GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3},
- + GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4},
- + GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- + {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- + });
- + }
- +
- + @Test
- + public void loadTensorBufferWithInvalidShape() {
- + TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(invalidShape, UINT8);
- + TensorImage tensorImage = new TensorImage();
- +
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> tensorImage.load(tensorBuffer, colorSpaceType));
- + assertThat(exception).hasMessageThat().contains(
- + errorMessage + Arrays.toString(invalidShape));
- + }
- + }
- +
- + private static void assertEqualTensorBuffers(
- + TensorBuffer tensorBuffer1, TensorBuffer tensorBuffer2) {
- + assertEqualByteBuffers(tensorBuffer1.getBuffer(), tensorBuffer2.getBuffer());
- + assertArrayEquals(tensorBuffer1.getShape(), tensorBuffer2.getShape());
- + }
- +
- + private static void assertEqualByteBuffers(ByteBuffer buffer1, ByteBuffer buffer2) {
- + buffer1.rewind();
- + buffer2.rewind();
- + assertThat(buffer1.equals(buffer2)).isTrue();
- + }
- +
- + private static void setUpImageMock(Image imageMock, int imageFormat) {
- + when(imageMock.getFormat()).thenReturn(imageFormat);
- }
- - }
- -
- - private static void assertEqualTensorBuffers(
- - TensorBuffer tensorBuffer1, TensorBuffer tensorBuffer2) {
- - assertEqualByteBuffers(tensorBuffer1.getBuffer(), tensorBuffer2.getBuffer());
- - assertArrayEquals(tensorBuffer1.getShape(), tensorBuffer2.getShape());
- - }
- -
- - private static void assertEqualByteBuffers(ByteBuffer buffer1, ByteBuffer buffer2) {
- - buffer1.rewind();
- - buffer2.rewind();
- - assertThat(buffer1.equals(buffer2)).isTrue();
- - }
- -
- - private static void setUpImageMock(Image imageMock, int imageFormat) {
- - when(imageMock.getFormat()).thenReturn(imageFormat);
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java
- index 7a5d0e9a9ea33..4ac2eca0b8cc6 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java
- @@ -17,109 +17,112 @@ package org.tensorflow.lite.support.image;
-
- import android.graphics.Bitmap;
- import android.graphics.Color;
- -import java.nio.ByteBuffer;
- +
- import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- +import java.nio.ByteBuffer;
- +
- /** Creates test images for other test files. */
- final class TestImageCreator {
- - /**
- - * Creates an example bitmap, which is a 10x10 ARGB bitmap and pixels are set by: <br>
- - * pixel[i] = {A: 255, B: i + 2, G: i + 1, R: i}, where i is the flatten index.
- - */
- - static Bitmap createRgbBitmap() {
- - int[] colors = new int[100];
- - for (int i = 0; i < 100; i++) {
- - colors[i] = Color.rgb(i, i + 1, i + 2);
- + /**
- + * Creates an example bitmap, which is a 10x10 ARGB bitmap and pixels are set by: <br>
- + * pixel[i] = {A: 255, B: i + 2, G: i + 1, R: i}, where i is the flatten index.
- + */
- + static Bitmap createRgbBitmap() {
- + int[] colors = new int[100];
- + for (int i = 0; i < 100; i++) {
- + colors[i] = Color.rgb(i, i + 1, i + 2);
- + }
- + return Bitmap.createBitmap(colors, 10, 10, Bitmap.Config.ARGB_8888);
- }
- - return Bitmap.createBitmap(colors, 10, 10, Bitmap.Config.ARGB_8888);
- - }
-
- - /**
- - * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap.
- - *
- - * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is
- - * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is[0, 1, 2, 3, ...].
- - *
- - * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w, 3)
- - */
- - static TensorBuffer createRgbTensorBuffer(DataType dataType, boolean isNormalized) {
- - return createRgbTensorBuffer(dataType, isNormalized, /*delta=*/ 0.1f);
- - }
- -
- - /**
- - * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap.
- - *
- - * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w)
- - * @param delta the delta that applied to the float values, such that the float array is [0 + +
- - * delta, 1+ delta, 2+ delta, 3+ delta, ...], while the uint8 array is [0, 1, 2, 3, ...]
- - */
- - static TensorBuffer createRgbTensorBuffer(DataType dataType, boolean isNormalized, float delta) {
- - float[] rgbValues = new float[300];
- - for (int i = 0, j = 0; i < 100; i++) {
- - rgbValues[j++] = i + delta;
- - rgbValues[j++] = i + 1 + delta;
- - rgbValues[j++] = i + 2 + delta;
- + /**
- + * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap.
- + *
- + * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is
- + * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is[0, 1, 2, 3, ...].
- + *
- + * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w, 3)
- + */
- + static TensorBuffer createRgbTensorBuffer(DataType dataType, boolean isNormalized) {
- + return createRgbTensorBuffer(dataType, isNormalized, /*delta=*/0.1f);
- }
-
- - int[] shape = isNormalized ? new int[] {1, 10, 10, 3} : new int[] {10, 10, 3};
- - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType);
- - // If dataType is UINT8, rgbValues will be converted into uint8, such as from
- - // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...].
- - buffer.loadArray(rgbValues, shape);
- - return buffer;
- - }
- + /**
- + * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap.
- + *
- + * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w)
- + * @param delta the delta that applied to the float values, such that the float array is [0 + +
- + * delta, 1+ delta, 2+ delta, 3+ delta, ...], while the uint8 array is [0, 1, 2, 3, ...]
- + */
- + static TensorBuffer createRgbTensorBuffer(
- + DataType dataType, boolean isNormalized, float delta) {
- + float[] rgbValues = new float[300];
- + for (int i = 0, j = 0; i < 100; i++) {
- + rgbValues[j++] = i + delta;
- + rgbValues[j++] = i + 1 + delta;
- + rgbValues[j++] = i + 2 + delta;
- + }
- +
- + int[] shape = isNormalized ? new int[] {1, 10, 10, 3} : new int[] {10, 10, 3};
- + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType);
- + // If dataType is UINT8, rgbValues will be converted into uint8, such as from
- + // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...].
- + buffer.loadArray(rgbValues, shape);
- + return buffer;
- + }
-
- - /**
- - * Creates an example bitmap, which is a 10x10 ALPHA_8 bitmap and pixels are set by: <br>
- - * pixel[i] = i, where i is the flatten index.
- - */
- - static Bitmap createGrayscaleBitmap() {
- - byte[] grayValues = new byte[100];
- - for (int i = 0; i < 100; i++) {
- - grayValues[i] = (byte) i;
- + /**
- + * Creates an example bitmap, which is a 10x10 ALPHA_8 bitmap and pixels are set by: <br>
- + * pixel[i] = i, where i is the flatten index.
- + */
- + static Bitmap createGrayscaleBitmap() {
- + byte[] grayValues = new byte[100];
- + for (int i = 0; i < 100; i++) {
- + grayValues[i] = (byte) i;
- + }
- + ByteBuffer buffer = ByteBuffer.wrap(grayValues);
- + Bitmap bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ALPHA_8);
- + buffer.rewind();
- + bitmap.copyPixelsFromBuffer(buffer);
- + return bitmap;
- }
- - ByteBuffer buffer = ByteBuffer.wrap(grayValues);
- - Bitmap bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ALPHA_8);
- - buffer.rewind();
- - bitmap.copyPixelsFromBuffer(buffer);
- - return bitmap;
- - }
-
- - /**
- - * Creates a 10*10 float or uint8 TensorBuffer representing the same image in
- - * createGrayscaleBitmap.
- - *
- - * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is
- - * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is[0, 1, 2, 3, ...].
- - *
- - * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w)
- - */
- - static TensorBuffer createGrayscaleTensorBuffer(DataType dataType, boolean isNormalized) {
- - return createGrayscaleTensorBuffer(dataType, isNormalized, /*delta=*/ 0.1f);
- - }
- + /**
- + * Creates a 10*10 float or uint8 TensorBuffer representing the same image in
- + * createGrayscaleBitmap.
- + *
- + * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is
- + * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is[0, 1, 2, 3, ...].
- + *
- + * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w)
- + */
- + static TensorBuffer createGrayscaleTensorBuffer(DataType dataType, boolean isNormalized) {
- + return createGrayscaleTensorBuffer(dataType, isNormalized, /*delta=*/0.1f);
- + }
-
- - /**
- - * Creates a 10*10 float or uint8 TensorBuffer representing the same image in
- - * createGrayscaleBitmap.
- - *
- - * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w)
- - * @param delta the delta that applied to the float values, such that the float array is [0 +
- - * delta, 1+ delta, 2+ delta, 3+ delta, ...], while the uint8 array is [0, 1, 2, 3, ...]
- - */
- - static TensorBuffer createGrayscaleTensorBuffer(
- - DataType dataType, boolean isNormalized, float delta) {
- - float[] grayValues = new float[100];
- - for (int i = 0; i < 100; i++) {
- - grayValues[i] = i + delta;
- + /**
- + * Creates a 10*10 float or uint8 TensorBuffer representing the same image in
- + * createGrayscaleBitmap.
- + *
- + * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w)
- + * @param delta the delta that applied to the float values, such that the float array is [0 +
- + * delta, 1+ delta, 2+ delta, 3+ delta, ...], while the uint8 array is [0, 1, 2, 3, ...]
- + */
- + static TensorBuffer createGrayscaleTensorBuffer(
- + DataType dataType, boolean isNormalized, float delta) {
- + float[] grayValues = new float[100];
- + for (int i = 0; i < 100; i++) {
- + grayValues[i] = i + delta;
- + }
- + int[] shape = isNormalized ? new int[] {1, 10, 10, 1} : new int[] {10, 10};
- + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType);
- + // If dataType is UINT8, grayValues will be converted into uint8, such as from
- + // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...].
- + buffer.loadArray(grayValues, shape);
- + return buffer;
- }
- - int[] shape = isNormalized ? new int[] {1, 10, 10, 1} : new int[] {10, 10};
- - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType);
- - // If dataType is UINT8, grayValues will be converted into uint8, such as from
- - // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...].
- - buffer.loadArray(grayValues, shape);
- - return buffer;
- - }
-
- - private TestImageCreator() {}
- + private TestImageCreator() {}
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java
- index a34f47d44c0ac..070e17893ad76 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java
- @@ -19,7 +19,9 @@ import static com.google.common.truth.Truth.assertThat;
-
- import android.graphics.Bitmap;
- import android.graphics.PointF;
- +
- import androidx.test.ext.junit.runners.AndroidJUnit4;
- +
- import org.junit.Before;
- import org.junit.Test;
- import org.junit.runner.RunWith;
- @@ -31,63 +33,62 @@ import org.tensorflow.lite.support.image.ops.ResizeOp.ResizeMethod;
- /** Instrumented unit test for {@link ResizeOp}. */
- @RunWith(AndroidJUnit4.class)
- public class ResizeOpInstrumentedTest {
- + private static final int EXAMPLE_WIDTH = 10;
- + private static final int EXAMPLE_HEIGHT = 15;
-
- - private static final int EXAMPLE_WIDTH = 10;
- - private static final int EXAMPLE_HEIGHT = 15;
- -
- - private Bitmap exampleBitmap;
- - private TensorImage input;
- + private Bitmap exampleBitmap;
- + private TensorImage input;
-
- - @Before
- - public void setUp() {
- - exampleBitmap = createExampleBitmap();
- - input = new TensorImage(DataType.UINT8);
- - input.load(exampleBitmap);
- - }
- + @Before
- + public void setUp() {
- + exampleBitmap = createExampleBitmap();
- + input = new TensorImage(DataType.UINT8);
- + input.load(exampleBitmap);
- + }
-
- - @Test
- - public void resizeShouldSuccess() {
- - int targetWidth = EXAMPLE_WIDTH * 2;
- - int targetHeight = EXAMPLE_HEIGHT * 2;
- - ImageProcessor processor =
- - new ImageProcessor.Builder()
- - .add(new ResizeOp(targetHeight, targetWidth, ResizeMethod.NEAREST_NEIGHBOR))
- - .build();
- - TensorImage output = processor.process(input);
- + @Test
- + public void resizeShouldSuccess() {
- + int targetWidth = EXAMPLE_WIDTH * 2;
- + int targetHeight = EXAMPLE_HEIGHT * 2;
- + ImageProcessor processor =
- + new ImageProcessor.Builder()
- + .add(new ResizeOp(targetHeight, targetWidth, ResizeMethod.NEAREST_NEIGHBOR))
- + .build();
- + TensorImage output = processor.process(input);
-
- - Bitmap outputBitmap = output.getBitmap();
- - assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
- - assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
- - for (int i = 0; i < outputBitmap.getWidth(); i++) {
- - for (int j = 0; j < outputBitmap.getHeight(); j++) {
- - int expected = exampleBitmap.getPixel(i / 2, j / 2);
- - assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
- - }
- + Bitmap outputBitmap = output.getBitmap();
- + assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
- + assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
- + for (int i = 0; i < outputBitmap.getWidth(); i++) {
- + for (int j = 0; j < outputBitmap.getHeight(); j++) {
- + int expected = exampleBitmap.getPixel(i / 2, j / 2);
- + assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
- + }
- + }
- }
- - }
-
- - @Test
- - public void inverseTransformPointShouldSuccess() {
- - ResizeOp op = new ResizeOp(200, 300, ResizeMethod.NEAREST_NEIGHBOR);
- - PointF transformed = new PointF(32.0f, 42.0f);
- - // The original image size is 900x400 assumed
- - PointF original = op.inverseTransform(transformed, 400, 900);
- - assertThat(original.x).isEqualTo(96);
- - assertThat(original.y).isEqualTo(84);
- - PointF outside = op.inverseTransform(new PointF(500, 1000), 400, 900);
- - assertThat(outside.x).isEqualTo(1500);
- - assertThat(outside.y).isEqualTo(2000);
- - }
- + @Test
- + public void inverseTransformPointShouldSuccess() {
- + ResizeOp op = new ResizeOp(200, 300, ResizeMethod.NEAREST_NEIGHBOR);
- + PointF transformed = new PointF(32.0f, 42.0f);
- + // The original image size is 900x400 assumed
- + PointF original = op.inverseTransform(transformed, 400, 900);
- + assertThat(original.x).isEqualTo(96);
- + assertThat(original.y).isEqualTo(84);
- + PointF outside = op.inverseTransform(new PointF(500, 1000), 400, 900);
- + assertThat(outside.x).isEqualTo(1500);
- + assertThat(outside.y).isEqualTo(2000);
- + }
-
- - /**
- - * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] = {A:
- - * 255, B: i + 2, G: i + 1, G: i}, where i is the flatten index
- - */
- - private static Bitmap createExampleBitmap() {
- - int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
- - for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
- - colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
- + /**
- + * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] =
- + * {A: 255, B: i + 2, G: i + 1, G: i}, where i is the flatten index
- + */
- + private static Bitmap createExampleBitmap() {
- + int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
- + for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
- + colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
- + }
- + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- }
- - return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java
- index 5c483780b30f4..85c777904f2ec 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java
- @@ -19,7 +19,9 @@ import static com.google.common.truth.Truth.assertThat;
-
- import android.graphics.Bitmap;
- import android.graphics.PointF;
- +
- import androidx.test.ext.junit.runners.AndroidJUnit4;
- +
- import org.junit.Before;
- import org.junit.Test;
- import org.junit.runner.RunWith;
- @@ -30,131 +32,128 @@ import org.tensorflow.lite.support.image.TensorImage;
- /** Instrumented unit test for {@link ResizeWithCropOrPadOp}. */
- @RunWith(AndroidJUnit4.class)
- public class ResizeWithCropOrPadOpInstrumentedTest {
- + private Bitmap exampleBitmap;
- + private TensorImage input;
-
- - private Bitmap exampleBitmap;
- - private TensorImage input;
- -
- - private static final int EXAMPLE_WIDTH = 10;
- - private static final int EXAMPLE_HEIGHT = 15;
- -
- - @Before
- - public void setUp() {
- - exampleBitmap = createExampleBitmap();
- - input = new TensorImage(DataType.UINT8);
- - input.load(exampleBitmap);
- - }
- -
- - @Test
- - public void testResizeWithCrop() {
- - int targetWidth = 6;
- - int targetHeight = 5;
- - ImageProcessor processor =
- - new ImageProcessor.Builder()
- - .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth))
- - .build();
- - TensorImage output = processor.process(input);
- -
- - Bitmap outputBitmap = output.getBitmap();
- - assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
- - assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
- - for (int i = 0; i < outputBitmap.getWidth(); i++) {
- - for (int j = 0; j < outputBitmap.getHeight(); j++) {
- - int expected =
- - exampleBitmap.getPixel(
- - i + (EXAMPLE_WIDTH - targetWidth) / 2, j + (EXAMPLE_HEIGHT - targetHeight) / 2);
- - assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
- - }
- + private static final int EXAMPLE_WIDTH = 10;
- + private static final int EXAMPLE_HEIGHT = 15;
- +
- + @Before
- + public void setUp() {
- + exampleBitmap = createExampleBitmap();
- + input = new TensorImage(DataType.UINT8);
- + input.load(exampleBitmap);
- }
- - }
- -
- - @Test
- - public void testResizeWithPad() {
- - int targetWidth = 15;
- - int targetHeight = 20;
- - ImageProcessor processor =
- - new ImageProcessor.Builder()
- - .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth))
- - .build();
- - TensorImage output = processor.process(input);
- - // Pad 2 rows / columns on top / left, and 3 rows / columns on bottom / right
- -
- - Bitmap outputBitmap = output.getBitmap();
- - assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
- - assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
- - int leftPad = (targetWidth - EXAMPLE_WIDTH) / 2;
- - int topPad = (targetHeight - EXAMPLE_HEIGHT) / 2;
- - for (int i = 0; i < outputBitmap.getWidth(); i++) {
- - for (int j = 0; j < outputBitmap.getHeight(); j++) {
- - int expected = 0; // ZERO padding
- - if (i >= leftPad
- - && i < leftPad + EXAMPLE_WIDTH
- - && j >= topPad
- - && j < topPad + EXAMPLE_HEIGHT) {
- - expected = exampleBitmap.getPixel(i - leftPad, j - topPad);
- +
- + @Test
- + public void testResizeWithCrop() {
- + int targetWidth = 6;
- + int targetHeight = 5;
- + ImageProcessor processor =
- + new ImageProcessor.Builder()
- + .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth))
- + .build();
- + TensorImage output = processor.process(input);
- +
- + Bitmap outputBitmap = output.getBitmap();
- + assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
- + assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
- + for (int i = 0; i < outputBitmap.getWidth(); i++) {
- + for (int j = 0; j < outputBitmap.getHeight(); j++) {
- + int expected = exampleBitmap.getPixel(i + (EXAMPLE_WIDTH - targetWidth) / 2,
- + j + (EXAMPLE_HEIGHT - targetHeight) / 2);
- + assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
- + }
- + }
- + }
- +
- + @Test
- + public void testResizeWithPad() {
- + int targetWidth = 15;
- + int targetHeight = 20;
- + ImageProcessor processor =
- + new ImageProcessor.Builder()
- + .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth))
- + .build();
- + TensorImage output = processor.process(input);
- + // Pad 2 rows / columns on top / left, and 3 rows / columns on bottom / right
- +
- + Bitmap outputBitmap = output.getBitmap();
- + assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
- + assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
- + int leftPad = (targetWidth - EXAMPLE_WIDTH) / 2;
- + int topPad = (targetHeight - EXAMPLE_HEIGHT) / 2;
- + for (int i = 0; i < outputBitmap.getWidth(); i++) {
- + for (int j = 0; j < outputBitmap.getHeight(); j++) {
- + int expected = 0; // ZERO padding
- + if (i >= leftPad && i < leftPad + EXAMPLE_WIDTH && j >= topPad
- + && j < topPad + EXAMPLE_HEIGHT) {
- + expected = exampleBitmap.getPixel(i - leftPad, j - topPad);
- + }
- + assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
- + }
- }
- - assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
- - }
- }
- - }
- -
- - @Test
- - public void testResizeWithCropAndPad() {
- - int targetSize = 12;
- - // Pad 1 column on left & right, crop 1 row on top and 2 rows on bottom
- - ImageProcessor processor =
- - new ImageProcessor.Builder().add(new ResizeWithCropOrPadOp(targetSize, targetSize)).build();
- - TensorImage output = processor.process(input);
- -
- - Bitmap outputBitmap = output.getBitmap();
- - assertThat(outputBitmap.getWidth()).isEqualTo(targetSize);
- - assertThat(outputBitmap.getHeight()).isEqualTo(targetSize);
- -
- - int leftPad = (targetSize - EXAMPLE_WIDTH) / 2;
- - int topCrop = (EXAMPLE_HEIGHT - targetSize) / 2;
- - for (int i = 0; i < outputBitmap.getWidth(); i++) {
- - for (int j = 0; j < outputBitmap.getHeight(); j++) {
- - int expected = 0;
- - if (i >= leftPad && i < leftPad + EXAMPLE_WIDTH) {
- - expected = exampleBitmap.getPixel(i - leftPad, j + topCrop);
- +
- + @Test
- + public void testResizeWithCropAndPad() {
- + int targetSize = 12;
- + // Pad 1 column on left & right, crop 1 row on top and 2 rows on bottom
- + ImageProcessor processor = new ImageProcessor.Builder()
- + .add(new ResizeWithCropOrPadOp(targetSize, targetSize))
- + .build();
- + TensorImage output = processor.process(input);
- +
- + Bitmap outputBitmap = output.getBitmap();
- + assertThat(outputBitmap.getWidth()).isEqualTo(targetSize);
- + assertThat(outputBitmap.getHeight()).isEqualTo(targetSize);
- +
- + int leftPad = (targetSize - EXAMPLE_WIDTH) / 2;
- + int topCrop = (EXAMPLE_HEIGHT - targetSize) / 2;
- + for (int i = 0; i < outputBitmap.getWidth(); i++) {
- + for (int j = 0; j < outputBitmap.getHeight(); j++) {
- + int expected = 0;
- + if (i >= leftPad && i < leftPad + EXAMPLE_WIDTH) {
- + expected = exampleBitmap.getPixel(i - leftPad, j + topCrop);
- + }
- + assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
- + }
- }
- - assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
- - }
- }
- - }
- -
- - @Test
- - public void inverseTransformCorrectlyWhenCropped() {
- - ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300);
- - // The point (100, 50) is transformed from 600x500 image
- - PointF original = op.inverseTransform(new PointF(100, 50), 500, 600);
- - assertThat(original.x).isEqualTo(250);
- - assertThat(original.y).isEqualTo(150);
- - PointF cropped = op.inverseTransform(new PointF(-10, -10), 500, 600);
- - assertThat(cropped.x).isEqualTo(140);
- - assertThat(cropped.y).isEqualTo(90);
- - }
- -
- - @Test
- - public void inverseTransformCorrectlyWhenPadded() {
- - ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300);
- - // The point (100, 50) is transformed from 100x200 image
- - PointF original = op.inverseTransform(new PointF(100, 50), 200, 100);
- - assertThat(original.x).isEqualTo(0);
- - assertThat(original.y).isEqualTo(0);
- - PointF outside = op.inverseTransform(new PointF(50, 10), 200, 100);
- - assertThat(outside.x).isEqualTo(-50);
- - assertThat(outside.y).isEqualTo(-40);
- - }
- -
- - /**
- - * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] = {A:
- - * 255, R: i + 2, G: i + 1, B: i}, where i is the flatten index
- - */
- - private static Bitmap createExampleBitmap() {
- - int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
- - for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
- - colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
- +
- + @Test
- + public void inverseTransformCorrectlyWhenCropped() {
- + ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300);
- + // The point (100, 50) is transformed from 600x500 image
- + PointF original = op.inverseTransform(new PointF(100, 50), 500, 600);
- + assertThat(original.x).isEqualTo(250);
- + assertThat(original.y).isEqualTo(150);
- + PointF cropped = op.inverseTransform(new PointF(-10, -10), 500, 600);
- + assertThat(cropped.x).isEqualTo(140);
- + assertThat(cropped.y).isEqualTo(90);
- + }
- +
- + @Test
- + public void inverseTransformCorrectlyWhenPadded() {
- + ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300);
- + // The point (100, 50) is transformed from 100x200 image
- + PointF original = op.inverseTransform(new PointF(100, 50), 200, 100);
- + assertThat(original.x).isEqualTo(0);
- + assertThat(original.y).isEqualTo(0);
- + PointF outside = op.inverseTransform(new PointF(50, 10), 200, 100);
- + assertThat(outside.x).isEqualTo(-50);
- + assertThat(outside.y).isEqualTo(-40);
- + }
- +
- + /**
- + * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] =
- + * {A: 255, R: i + 2, G: i + 1, B: i}, where i is the flatten index
- + */
- + private static Bitmap createExampleBitmap() {
- + int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
- + for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
- + colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
- + }
- + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- }
- - return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java
- index eb54788764f1e..d00fe0e44422e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java
- @@ -19,7 +19,9 @@ import static com.google.common.truth.Truth.assertThat;
-
- import android.graphics.Bitmap;
- import android.graphics.PointF;
- +
- import androidx.test.ext.junit.runners.AndroidJUnit4;
- +
- import org.junit.Before;
- import org.junit.Test;
- import org.junit.runner.RunWith;
- @@ -30,68 +32,68 @@ import org.tensorflow.lite.support.image.TensorImage;
- /** Instrumented unit test for {@link Rot90Op}. */
- @RunWith(AndroidJUnit4.class)
- public class Rot90OpInstrumentedTest {
- + private Bitmap exampleBitmap;
- + private TensorImage input;
- +
- + private static final int EXAMPLE_WIDTH = 10;
- + private static final int EXAMPLE_HEIGHT = 15;
-
- - private Bitmap exampleBitmap;
- - private TensorImage input;
- -
- - private static final int EXAMPLE_WIDTH = 10;
- - private static final int EXAMPLE_HEIGHT = 15;
- -
- - @Before
- - public void setUp() {
- - exampleBitmap = createExampleBitmap();
- - input = new TensorImage(DataType.UINT8);
- - input.load(exampleBitmap);
- - }
- -
- - @Test
- - public void testRot90() {
- - ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op()).build();
- - TensorImage output = processor.process(input);
- -
- - Bitmap outputBitmap = output.getBitmap();
- - assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_HEIGHT);
- - assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_WIDTH);
- - for (int i = 0; i < exampleBitmap.getWidth(); i++) {
- - for (int j = 0; j < exampleBitmap.getHeight(); j++) {
- - assertThat(exampleBitmap.getPixel(i, j))
- - .isEqualTo(outputBitmap.getPixel(j, EXAMPLE_WIDTH - 1 - i));
- - }
- + @Before
- + public void setUp() {
- + exampleBitmap = createExampleBitmap();
- + input = new TensorImage(DataType.UINT8);
- + input.load(exampleBitmap);
- }
- - }
- -
- - @Test
- - public void testRot90Twice() {
- - ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op(2)).build();
- - TensorImage output = processor.process(input);
- -
- - Bitmap outputBitmap = output.getBitmap();
- - assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_WIDTH);
- - assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
- - for (int i = 0; i < exampleBitmap.getWidth(); i++) {
- - for (int j = 0; j < exampleBitmap.getHeight(); j++) {
- - assertThat(exampleBitmap.getPixel(i, j))
- - .isEqualTo(outputBitmap.getPixel(EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j));
- - }
- +
- + @Test
- + public void testRot90() {
- + ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op()).build();
- + TensorImage output = processor.process(input);
- +
- + Bitmap outputBitmap = output.getBitmap();
- + assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_HEIGHT);
- + assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_WIDTH);
- + for (int i = 0; i < exampleBitmap.getWidth(); i++) {
- + for (int j = 0; j < exampleBitmap.getHeight(); j++) {
- + assertThat(exampleBitmap.getPixel(i, j))
- + .isEqualTo(outputBitmap.getPixel(j, EXAMPLE_WIDTH - 1 - i));
- + }
- + }
- }
- - }
- -
- - @Test
- - public void inverseTransformCorrectlyWhenRotated() {
- - Rot90Op op = new Rot90Op(3);
- - PointF original = op.inverseTransform(new PointF(20, 10), 200, 100);
- - assertThat(original.x).isEqualTo(10);
- - assertThat(original.y).isEqualTo(180);
- - PointF outside = op.inverseTransform(new PointF(-10, 110), 200, 100);
- - assertThat(outside.x).isEqualTo(110);
- - assertThat(outside.y).isEqualTo(210);
- - }
- -
- - private static Bitmap createExampleBitmap() {
- - int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
- - for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
- - colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
- +
- + @Test
- + public void testRot90Twice() {
- + ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op(2)).build();
- + TensorImage output = processor.process(input);
- +
- + Bitmap outputBitmap = output.getBitmap();
- + assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_WIDTH);
- + assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
- + for (int i = 0; i < exampleBitmap.getWidth(); i++) {
- + for (int j = 0; j < exampleBitmap.getHeight(); j++) {
- + assertThat(exampleBitmap.getPixel(i, j))
- + .isEqualTo(outputBitmap.getPixel(
- + EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j));
- + }
- + }
- + }
- +
- + @Test
- + public void inverseTransformCorrectlyWhenRotated() {
- + Rot90Op op = new Rot90Op(3);
- + PointF original = op.inverseTransform(new PointF(20, 10), 200, 100);
- + assertThat(original.x).isEqualTo(10);
- + assertThat(original.y).isEqualTo(180);
- + PointF outside = op.inverseTransform(new PointF(-10, 110), 200, 100);
- + assertThat(outside.x).isEqualTo(110);
- + assertThat(outside.y).isEqualTo(210);
- + }
- +
- + private static Bitmap createExampleBitmap() {
- + int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
- + for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
- + colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
- + }
- + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- }
- - return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java
- index 46713fd486fa7..f024f68911d27 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java
- @@ -16,6 +16,7 @@ limitations under the License.
- package org.tensorflow.lite.support.image.ops;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertThrows;
- import static org.mockito.Mockito.doReturn;
- import static org.tensorflow.lite.DataType.UINT8;
- @@ -24,7 +25,9 @@ import android.graphics.Bitmap;
- import android.graphics.Color;
- import android.graphics.ImageFormat;
- import android.media.Image;
- +
- import androidx.test.ext.junit.runners.AndroidJUnit4;
- +
- import org.junit.Before;
- import org.junit.Rule;
- import org.junit.Test;
- @@ -40,54 +43,55 @@ import org.tensorflow.lite.support.image.TensorImage;
- /** Instrumented unit test for {@link TransformToGrayscaleOp}. */
- @RunWith(AndroidJUnit4.class)
- public class TransformToGrayScaleOpInstrumentedTest {
- -
- - @Rule public final MockitoRule mockito = MockitoJUnit.rule();
- -
- - private TensorImage input;
- -
- - private static final int EXAMPLE_WIDTH = 2;
- - private static final int EXAMPLE_HEIGHT = 3;
- - @Mock Image imageMock;
- -
- - @Before
- - public void setUp() {
- - Bitmap exampleBitmap = createExampleBitmap();
- - input = new TensorImage(DataType.UINT8);
- - input.load(exampleBitmap);
- - }
- -
- - @Test
- - public void apply_onRgb_succeeds() {
- - ImageProcessor processor =
- - new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build();
- -
- - TensorImage output = processor.process(input);
- - int[] pixels = output.getTensorBuffer().getIntArray();
- -
- - assertThat(output.getWidth()).isEqualTo(EXAMPLE_WIDTH);
- - assertThat(output.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
- - assertThat(output.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
- - assertThat(pixels).isEqualTo(new int[] {0, 255, 76, 29, 150, 179});
- - }
- -
- - @Test
- - public void apply_onYuv_throws() {
- - setUpImageMock(imageMock, ImageFormat.YUV_420_888);
- - TensorImage tensorImage = new TensorImage(UINT8);
- - tensorImage.load(imageMock);
- - ImageProcessor processor =
- - new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build();
- -
- - assertThrows(IllegalArgumentException.class, () -> processor.process(tensorImage));
- - }
- -
- - private static Bitmap createExampleBitmap() {
- - int[] colors =
- - new int[] {Color.BLACK, Color.WHITE, Color.RED, Color.BLUE, Color.GREEN, Color.CYAN};
- - return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- - }
- -
- - private static void setUpImageMock(Image imageMock, int imageFormat) {
- - doReturn(imageFormat).when(imageMock).getFormat();
- - }
- + @Rule
- + public final MockitoRule mockito = MockitoJUnit.rule();
- +
- + private TensorImage input;
- +
- + private static final int EXAMPLE_WIDTH = 2;
- + private static final int EXAMPLE_HEIGHT = 3;
- + @Mock
- + Image imageMock;
- +
- + @Before
- + public void setUp() {
- + Bitmap exampleBitmap = createExampleBitmap();
- + input = new TensorImage(DataType.UINT8);
- + input.load(exampleBitmap);
- + }
- +
- + @Test
- + public void apply_onRgb_succeeds() {
- + ImageProcessor processor =
- + new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build();
- +
- + TensorImage output = processor.process(input);
- + int[] pixels = output.getTensorBuffer().getIntArray();
- +
- + assertThat(output.getWidth()).isEqualTo(EXAMPLE_WIDTH);
- + assertThat(output.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
- + assertThat(output.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
- + assertThat(pixels).isEqualTo(new int[] {0, 255, 76, 29, 150, 179});
- + }
- +
- + @Test
- + public void apply_onYuv_throws() {
- + setUpImageMock(imageMock, ImageFormat.YUV_420_888);
- + TensorImage tensorImage = new TensorImage(UINT8);
- + tensorImage.load(imageMock);
- + ImageProcessor processor =
- + new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build();
- +
- + assertThrows(IllegalArgumentException.class, () -> processor.process(tensorImage));
- + }
- +
- + private static Bitmap createExampleBitmap() {
- + int[] colors = new int[] {
- + Color.BLACK, Color.WHITE, Color.RED, Color.BLUE, Color.GREEN, Color.CYAN};
- + return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- + }
- +
- + private static void setUpImageMock(Image imageMock, int imageFormat) {
- + doReturn(imageFormat).when(imageMock).getFormat();
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java
- index 28620dd941e9c..98d1f92f56c6d 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java
- @@ -24,114 +24,98 @@ import org.robolectric.RobolectricTestRunner;
- /** Tests of {@link org.tensorflow.lite.support.label.Category}. */
- @RunWith(RobolectricTestRunner.class)
- public final class CategoryTest {
- - private static final String APPLE_LABEL = "apple";
- - private static final String DEFAULT_DISPLAY_NAME = "";
- - private static final String APPLE_DISPLAY_NAME = "manzana"; // "apple" in Spanish.
- - private static final float APPLE_SCORE = 0.5f;
- - private static final int APPLE_INDEX = 10;
- -
- - @Test
- - public void createShouldSucceed() {
- - Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
- -
- - assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
- - assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME);
- - assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
- - }
- -
- - @Test
- - public void createWithIndexShouldSucceed() {
- - Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
- -
- - assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
- - assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME);
- - assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
- - assertThat(category.getIndex()).isEqualTo(APPLE_INDEX);
- - }
- -
- - @Test
- - public void constructorShouldSucceed() {
- - Category category = new Category(APPLE_LABEL, APPLE_SCORE);
- -
- - assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
- - // Using the constructor, displayName will be default to an empty string.
- - assertThat(category.getDisplayName()).isEqualTo(DEFAULT_DISPLAY_NAME);
- - assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
- - }
- -
- - @Test
- - public void toStringWithCreateShouldProvideReadableResult() {
- - Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
- - String categoryString = category.toString();
- -
- - assertThat(categoryString)
- - .isEqualTo(
- - "<Category \""
- - + APPLE_LABEL
- - + "\" (displayName="
- - + APPLE_DISPLAY_NAME
- - + " score="
- - + APPLE_SCORE
- - + " index=-1"
- - + ")>");
- - }
- -
- - @Test
- - public void toStringWithCreateIndexShouldProvideReadableResult() {
- - Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
- - String categoryString = category.toString();
- -
- - assertThat(categoryString)
- - .isEqualTo(
- - "<Category \""
- - + APPLE_LABEL
- - + "\" (displayName="
- - + APPLE_DISPLAY_NAME
- - + " score="
- - + APPLE_SCORE
- - + " index="
- - + APPLE_INDEX
- - + ")>");
- - }
- -
- - @Test
- - public void toStringWithConstuctorShouldProvideReadableResult() {
- - Category category = new Category(APPLE_LABEL, APPLE_SCORE);
- - String categoryString = category.toString();
- -
- - assertThat(categoryString)
- - .isEqualTo(
- - "<Category \""
- - + APPLE_LABEL
- - + "\" (displayName="
- - + DEFAULT_DISPLAY_NAME
- - + " score="
- - + APPLE_SCORE
- - + " index=-1"
- - + ")>");
- - }
- -
- - @Test
- - public void equalsShouldSucceedWithCreate() {
- - Category categoryA = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
- - Category categoryB = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
- -
- - assertThat(categoryA).isEqualTo(categoryB);
- - }
- -
- - @Test
- - public void equalsShouldSucceedWithCreateIndex() {
- - Category categoryA = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
- - Category categoryB = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
- -
- - assertThat(categoryA).isEqualTo(categoryB);
- - }
- -
- - @Test
- - public void equalsShouldSucceedWithConstructor() {
- - Category categoryA = new Category(APPLE_LABEL, APPLE_SCORE);
- - Category categoryB = new Category(APPLE_LABEL, APPLE_SCORE);
- -
- - assertThat(categoryA).isEqualTo(categoryB);
- - }
- + private static final String APPLE_LABEL = "apple";
- + private static final String DEFAULT_DISPLAY_NAME = "";
- + private static final String APPLE_DISPLAY_NAME = "manzana"; // "apple" in Spanish.
- + private static final float APPLE_SCORE = 0.5f;
- + private static final int APPLE_INDEX = 10;
- +
- + @Test
- + public void createShouldSucceed() {
- + Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
- +
- + assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
- + assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME);
- + assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
- + }
- +
- + @Test
- + public void createWithIndexShouldSucceed() {
- + Category category =
- + Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
- +
- + assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
- + assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME);
- + assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
- + assertThat(category.getIndex()).isEqualTo(APPLE_INDEX);
- + }
- +
- + @Test
- + public void constructorShouldSucceed() {
- + Category category = new Category(APPLE_LABEL, APPLE_SCORE);
- +
- + assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
- + // Using the constructor, displayName will be default to an empty string.
- + assertThat(category.getDisplayName()).isEqualTo(DEFAULT_DISPLAY_NAME);
- + assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
- + }
- +
- + @Test
- + public void toStringWithCreateShouldProvideReadableResult() {
- + Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
- + String categoryString = category.toString();
- +
- + assertThat(categoryString)
- + .isEqualTo("<Category \"" + APPLE_LABEL + "\" (displayName=" + APPLE_DISPLAY_NAME
- + + " score=" + APPLE_SCORE + " index=-1"
- + + ")>");
- + }
- +
- + @Test
- + public void toStringWithCreateIndexShouldProvideReadableResult() {
- + Category category =
- + Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
- + String categoryString = category.toString();
- +
- + assertThat(categoryString)
- + .isEqualTo("<Category \"" + APPLE_LABEL + "\" (displayName=" + APPLE_DISPLAY_NAME
- + + " score=" + APPLE_SCORE + " index=" + APPLE_INDEX + ")>");
- + }
- +
- + @Test
- + public void toStringWithConstuctorShouldProvideReadableResult() {
- + Category category = new Category(APPLE_LABEL, APPLE_SCORE);
- + String categoryString = category.toString();
- +
- + assertThat(categoryString)
- + .isEqualTo("<Category \"" + APPLE_LABEL + "\" (displayName=" + DEFAULT_DISPLAY_NAME
- + + " score=" + APPLE_SCORE + " index=-1"
- + + ")>");
- + }
- +
- + @Test
- + public void equalsShouldSucceedWithCreate() {
- + Category categoryA = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
- + Category categoryB = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
- +
- + assertThat(categoryA).isEqualTo(categoryB);
- + }
- +
- + @Test
- + public void equalsShouldSucceedWithCreateIndex() {
- + Category categoryA =
- + Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
- + Category categoryB =
- + Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
- +
- + assertThat(categoryA).isEqualTo(categoryB);
- + }
- +
- + @Test
- + public void equalsShouldSucceedWithConstructor() {
- + Category categoryA = new Category(APPLE_LABEL, APPLE_SCORE);
- + Category categoryB = new Category(APPLE_LABEL, APPLE_SCORE);
- +
- + assertThat(categoryA).isEqualTo(categoryB);
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java
- index caa468bb0a9ec..91c81c4932b81 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java
- @@ -17,35 +17,38 @@ package org.tensorflow.lite.support.label;
-
- import static com.google.common.truth.Truth.assertThat;
-
- -import java.util.Arrays;
- -import java.util.List;
- import org.junit.Test;
- import org.junit.runner.RunWith;
- import org.robolectric.RobolectricTestRunner;
- import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- +import java.util.Arrays;
- +import java.util.List;
- +
- /** Tests of {@link org.tensorflow.lite.support.label.LabelUtil}. */
- @RunWith(RobolectricTestRunner.class)
- public class LabelUtilTest {
- -
- - @Test
- - public void mapIndexToStringsWithInvalidValues() {
- - String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"};
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- - tensorBuffer.loadArray(new int[] {0, 1, 2, 3, 2, 5}, new int[] {1, 6});
- - List<String> categories = LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1);
- - assertThat(categories.toArray())
- - .isEqualTo(new String[] {"apple", "banana", "cherry", "date", "cherry", ""});
- - }
- -
- - @Test
- - public void mapFloatIndexShouldCast() {
- - String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"};
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- - tensorBuffer.loadArray(new float[] {-1.1f, -0.3f, 0.3f, 1.2f, 1.8f, 1}, new int[] {1, 6});
- - List<String> categories = LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1);
- - assertThat(categories.toArray())
- - .isEqualTo(new String[] {"background", "apple", "apple", "banana", "banana", "banana"});
- - }
- + @Test
- + public void mapIndexToStringsWithInvalidValues() {
- + String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"};
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- + tensorBuffer.loadArray(new int[] {0, 1, 2, 3, 2, 5}, new int[] {1, 6});
- + List<String> categories =
- + LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1);
- + assertThat(categories.toArray())
- + .isEqualTo(new String[] {"apple", "banana", "cherry", "date", "cherry", ""});
- + }
- +
- + @Test
- + public void mapFloatIndexShouldCast() {
- + String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"};
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- + tensorBuffer.loadArray(new float[] {-1.1f, -0.3f, 0.3f, 1.2f, 1.8f, 1}, new int[] {1, 6});
- + List<String> categories =
- + LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1);
- + assertThat(categories.toArray())
- + .isEqualTo(new String[] {
- + "background", "apple", "apple", "banana", "banana", "banana"});
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java
- index 4f296b7476c2d..857a77a2a4bd4 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java
- @@ -17,10 +17,6 @@ package org.tensorflow.lite.support.label;
-
- import static com.google.common.truth.Truth.assertThat;
-
- -import java.util.Arrays;
- -import java.util.HashMap;
- -import java.util.List;
- -import java.util.Map;
- import org.junit.Assert;
- import org.junit.Test;
- import org.junit.runner.RunWith;
- @@ -28,169 +24,180 @@ import org.robolectric.RobolectricTestRunner;
- import org.tensorflow.lite.DataType;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- +import java.util.Arrays;
- +import java.util.HashMap;
- +import java.util.List;
- +import java.util.Map;
- +
- /** Tests of {@link org.tensorflow.lite.support.label.TensorLabel}. */
- @RunWith(RobolectricTestRunner.class)
- public final class TensorLabelTest {
- - @Test
- - public void createTensorLabelWithNullAxisLabelsShouldFail() {
- - int[] shape = {2};
- - int[] arr = {1, 2};
- - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- - buffer.loadArray(arr, shape);
- - Map<Integer, List<String>> nullAxisLabels = null;
- -
- - Assert.assertThrows(NullPointerException.class, () -> new TensorLabel(nullAxisLabels, buffer));
- - }
- -
- - @Test
- - public void createTensorLabelWithNullTensorBufferShouldFail() {
- - Map<Integer, List<String>> axisLabels = new HashMap<>();
- - axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
- - TensorBuffer nullTensorBuffer = null;
- -
- - Assert.assertThrows(
- - NullPointerException.class, () -> new TensorLabel(axisLabels, nullTensorBuffer));
- - }
- -
- - @Test
- - public void createTensorLabelWithStringListShouldSuccess() {
- - TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 4, 3}, DataType.FLOAT32);
- -
- - TensorLabel tensorLabel = new TensorLabel(Arrays.asList("a", "b", "c", "d"), buffer);
- -
- - assertThat(tensorLabel.getMapWithTensorBuffer()).isNotNull();
- - assertThat(tensorLabel.getMapWithTensorBuffer().keySet()).contains("c"); // randomly pick one
- - }
- -
- - @Test
- - public void createTensorLabelWithEmptyShapeShouldFail() {
- - int[] shape = new int[] {};
- - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- - Map<Integer, List<String>> axisLabels = new HashMap<>();
- - axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
- -
- - Assert.assertThrows(IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
- - }
- -
- - @Test
- - public void createTensorLabelWithMismatchedAxisShouldFail() {
- - int[] shape = {1, 4};
- - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- - Map<Integer, List<String>> axisLabels = new HashMap<>();
- - axisLabels.put(0, Arrays.asList("a", "b", "c", "d"));
- -
- - Assert.assertThrows(IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
- - }
- -
- - @Test
- - public void createTensorLabelWithMismatchedShapeShouldFail() {
- - int[] shape = {1, 3};
- - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- - Map<Integer, List<String>> axisLabels = new HashMap<>();
- - axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
- -
- - Assert.assertThrows(IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
- - }
- -
- - @Test
- - public void getMapWithFloatBufferValuesShouldSuccess() {
- - int numberLabel = 4;
- - float[] inputArr = {0.5f, 0.2f, 0.2f, 0.1f};
- - int[] shape = {1, numberLabel};
- - TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- - input.loadArray(inputArr, shape);
- - Map<Integer, List<String>> axisLabels = new HashMap<>();
- - int labelAxis = 1;
- - axisLabels.put(labelAxis, Arrays.asList("a", "b", "c", "d"));
- -
- - TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
- - Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
- -
- - for (int i = 0; i < numberLabel; i++) {
- - String label = axisLabels.get(labelAxis).get(i);
- - assertThat(map).containsKey(label);
- - float[] array = map.get(label).getFloatArray();
- - assertThat(array).hasLength(1);
- - assertThat(array[0]).isEqualTo(inputArr[i]);
- + @Test
- + public void createTensorLabelWithNullAxisLabelsShouldFail() {
- + int[] shape = {2};
- + int[] arr = {1, 2};
- + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- + buffer.loadArray(arr, shape);
- + Map<Integer, List<String>> nullAxisLabels = null;
- +
- + Assert.assertThrows(
- + NullPointerException.class, () -> new TensorLabel(nullAxisLabels, buffer));
- + }
- +
- + @Test
- + public void createTensorLabelWithNullTensorBufferShouldFail() {
- + Map<Integer, List<String>> axisLabels = new HashMap<>();
- + axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
- + TensorBuffer nullTensorBuffer = null;
- +
- + Assert.assertThrows(
- + NullPointerException.class, () -> new TensorLabel(axisLabels, nullTensorBuffer));
- + }
- +
- + @Test
- + public void createTensorLabelWithStringListShouldSuccess() {
- + TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 4, 3}, DataType.FLOAT32);
- +
- + TensorLabel tensorLabel = new TensorLabel(Arrays.asList("a", "b", "c", "d"), buffer);
- +
- + assertThat(tensorLabel.getMapWithTensorBuffer()).isNotNull();
- + assertThat(tensorLabel.getMapWithTensorBuffer().keySet())
- + .contains("c"); // randomly pick one
- + }
- +
- + @Test
- + public void createTensorLabelWithEmptyShapeShouldFail() {
- + int[] shape = new int[] {};
- + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- + Map<Integer, List<String>> axisLabels = new HashMap<>();
- + axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
- +
- + Assert.assertThrows(
- + IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
- }
- - }
- -
- - @Test
- - public void getMapWithIntBufferValuesShouldSuccess() {
- - int numberLabel = 3;
- - int[] inputArr = {1, 2, 0};
- - int[] shape = {1, 1, numberLabel};
- - TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- - input.loadArray(inputArr, shape);
- - Map<Integer, List<String>> axisLabels = new HashMap<>();
- - int labelAxis = 2;
- - axisLabels.put(labelAxis, Arrays.asList("x", "y", "z"));
- -
- - TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
- - Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
- -
- - for (int i = 0; i < numberLabel; i++) {
- - String label = axisLabels.get(labelAxis).get(i);
- - assertThat(map).containsKey(label);
- - int[] array = map.get(label).getIntArray();
- - assertThat(array).hasLength(1);
- - assertThat(array[0]).isEqualTo(inputArr[i]);
- +
- + @Test
- + public void createTensorLabelWithMismatchedAxisShouldFail() {
- + int[] shape = {1, 4};
- + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- + Map<Integer, List<String>> axisLabels = new HashMap<>();
- + axisLabels.put(0, Arrays.asList("a", "b", "c", "d"));
- +
- + Assert.assertThrows(
- + IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
- }
- - }
- -
- - @Test
- - public void getFloatMapShouldSuccess() {
- - int[] shape = {1, 3};
- - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- - buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f});
- -
- - TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer);
- - Map<String, Float> map = tensorLabeled.getMapWithFloatValue();
- -
- - assertThat(map).hasSize(3);
- - assertThat(map).containsEntry("a", 1.0f);
- - assertThat(map).containsEntry("b", 2.0f);
- - assertThat(map).containsEntry("c", 3.0f);
- - }
- -
- - @Test
- - public void getMapFromMultiDimensionalTensorBufferShouldSuccess() {
- - int numberLabel = 2;
- - int numDim = 3;
- - float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f};
- - int[] shape = {numberLabel, numDim};
- - TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- - input.loadArray(inputArr, shape);
- - Map<Integer, List<String>> axisLabels = new HashMap<>();
- - int labelAxis = 0;
- - axisLabels.put(labelAxis, Arrays.asList("pos", "neg"));
- -
- - TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
- - Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
- -
- - for (int i = 0; i < numberLabel; i++) {
- - String label = axisLabels.get(labelAxis).get(i);
- - assertThat(map).containsKey(label);
- -
- - float[] array = map.get(label).getFloatArray();
- - assertThat(array).hasLength(numDim);
- - for (int j = 0; j < numDim; j++) {
- - assertThat(array[j]).isEqualTo(inputArr[i * numDim + j]);
- - }
- +
- + @Test
- + public void createTensorLabelWithMismatchedShapeShouldFail() {
- + int[] shape = {1, 3};
- + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- + Map<Integer, List<String>> axisLabels = new HashMap<>();
- + axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
- +
- + Assert.assertThrows(
- + IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
- + }
- +
- + @Test
- + public void getMapWithFloatBufferValuesShouldSuccess() {
- + int numberLabel = 4;
- + float[] inputArr = {0.5f, 0.2f, 0.2f, 0.1f};
- + int[] shape = {1, numberLabel};
- + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- + input.loadArray(inputArr, shape);
- + Map<Integer, List<String>> axisLabels = new HashMap<>();
- + int labelAxis = 1;
- + axisLabels.put(labelAxis, Arrays.asList("a", "b", "c", "d"));
- +
- + TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
- + Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
- +
- + for (int i = 0; i < numberLabel; i++) {
- + String label = axisLabels.get(labelAxis).get(i);
- + assertThat(map).containsKey(label);
- + float[] array = map.get(label).getFloatArray();
- + assertThat(array).hasLength(1);
- + assertThat(array[0]).isEqualTo(inputArr[i]);
- + }
- }
- - }
-
- - @Test
- - public void getCategoryListShouldSuccess() {
- - int[] shape = {1, 3};
- - TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- - buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f});
- + @Test
- + public void getMapWithIntBufferValuesShouldSuccess() {
- + int numberLabel = 3;
- + int[] inputArr = {1, 2, 0};
- + int[] shape = {1, 1, numberLabel};
- + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- + input.loadArray(inputArr, shape);
- + Map<Integer, List<String>> axisLabels = new HashMap<>();
- + int labelAxis = 2;
- + axisLabels.put(labelAxis, Arrays.asList("x", "y", "z"));
- +
- + TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
- + Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
- +
- + for (int i = 0; i < numberLabel; i++) {
- + String label = axisLabels.get(labelAxis).get(i);
- + assertThat(map).containsKey(label);
- + int[] array = map.get(label).getIntArray();
- + assertThat(array).hasLength(1);
- + assertThat(array[0]).isEqualTo(inputArr[i]);
- + }
- + }
-
- - TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer);
- - List<Category> categories = tensorLabeled.getCategoryList();
- + @Test
- + public void getFloatMapShouldSuccess() {
- + int[] shape = {1, 3};
- + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- + buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f});
-
- - assertThat(categories).hasSize(3);
- - assertThat(categories)
- - .containsExactly(new Category("a", 1.0f), new Category("b", 2.0f), new Category("c", 3.0f));
- - }
- + TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer);
- + Map<String, Float> map = tensorLabeled.getMapWithFloatValue();
- +
- + assertThat(map).hasSize(3);
- + assertThat(map).containsEntry("a", 1.0f);
- + assertThat(map).containsEntry("b", 2.0f);
- + assertThat(map).containsEntry("c", 3.0f);
- + }
- +
- + @Test
- + public void getMapFromMultiDimensionalTensorBufferShouldSuccess() {
- + int numberLabel = 2;
- + int numDim = 3;
- + float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f};
- + int[] shape = {numberLabel, numDim};
- + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- + input.loadArray(inputArr, shape);
- + Map<Integer, List<String>> axisLabels = new HashMap<>();
- + int labelAxis = 0;
- + axisLabels.put(labelAxis, Arrays.asList("pos", "neg"));
- +
- + TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
- + Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
- +
- + for (int i = 0; i < numberLabel; i++) {
- + String label = axisLabels.get(labelAxis).get(i);
- + assertThat(map).containsKey(label);
- +
- + float[] array = map.get(label).getFloatArray();
- + assertThat(array).hasLength(numDim);
- + for (int j = 0; j < numDim; j++) {
- + assertThat(array[j]).isEqualTo(inputArr[i * numDim + j]);
- + }
- + }
- + }
- +
- + @Test
- + public void getCategoryListShouldSuccess() {
- + int[] shape = {1, 3};
- + TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- + buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f});
- +
- + TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer);
- + List<Category> categories = tensorLabeled.getCategoryList();
- +
- + assertThat(categories).hasSize(3);
- + assertThat(categories)
- + .containsExactly(
- + new Category("a", 1.0f), new Category("b", 2.0f), new Category("c", 3.0f));
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java
- index 8fa8860a09ef5..c1afe99f34f34 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java
- @@ -18,11 +18,9 @@ package org.tensorflow.lite.support.label.ops;
- import static com.google.common.truth.Truth.assertThat;
-
- import android.content.Context;
- +
- import androidx.test.core.app.ApplicationProvider;
- -import java.io.IOException;
- -import java.util.Arrays;
- -import java.util.List;
- -import java.util.Map;
- +
- import org.junit.Test;
- import org.junit.runner.RunWith;
- import org.robolectric.RobolectricTestRunner;
- @@ -31,90 +29,94 @@ import org.tensorflow.lite.support.common.FileUtil;
- import org.tensorflow.lite.support.label.TensorLabel;
- import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
-
- +import java.io.IOException;
- +import java.util.Arrays;
- +import java.util.List;
- +import java.util.Map;
- +
- /** Tests of {@link org.tensorflow.lite.support.label.ops.LabelAxisOp}. */
- @RunWith(RobolectricTestRunner.class)
- public final class LabelAxisOpTest {
- + private final Context context = ApplicationProvider.getApplicationContext();
- + private static final String LABEL_PATH = "flower_labels.txt";
- +
- + @Test
- + public void testAddAxisLabelByStringList() {
- + int numberLabel = 2;
- + float[] inputArr = {0.7f, 0.3f};
- +
- + int[] shape = {numberLabel};
- + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- + input.loadArray(inputArr, shape);
- +
- + List<String> labels = Arrays.asList("pos", "neg");
- + LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(0, labels).build();
- + TensorLabel output = op.apply(input);
- + Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
- +
- + assertThat(map).containsKey("pos");
- + float[] array = map.get("pos").getFloatArray();
- + assertThat(array).hasLength(1);
- + assertThat(array[0]).isEqualTo(0.7f);
- +
- + assertThat(map).containsKey("neg");
- + array = map.get("neg").getFloatArray();
- + assertThat(array).hasLength(1);
- + assertThat(array[0]).isEqualTo(0.3f);
- + }
- +
- + @Test
- + public void testAddAxisLabelWithMultiDimensionTensor() throws IOException {
- + int numberLabel = 2;
- + int numDim = 3;
- + float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f};
- +
- + int[] shape = {1, numberLabel, numDim};
- + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- + input.loadArray(inputArr, shape);
-
- - private final Context context = ApplicationProvider.getApplicationContext();
- - private static final String LABEL_PATH = "flower_labels.txt";
- -
- - @Test
- - public void testAddAxisLabelByStringList() {
- - int numberLabel = 2;
- - float[] inputArr = {0.7f, 0.3f};
- -
- - int[] shape = {numberLabel};
- - TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- - input.loadArray(inputArr, shape);
- -
- - List<String> labels = Arrays.asList("pos", "neg");
- - LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(0, labels).build();
- - TensorLabel output = op.apply(input);
- - Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
- -
- - assertThat(map).containsKey("pos");
- - float[] array = map.get("pos").getFloatArray();
- - assertThat(array).hasLength(1);
- - assertThat(array[0]).isEqualTo(0.7f);
- -
- - assertThat(map).containsKey("neg");
- - array = map.get("neg").getFloatArray();
- - assertThat(array).hasLength(1);
- - assertThat(array[0]).isEqualTo(0.3f);
- - }
- -
- - @Test
- - public void testAddAxisLabelWithMultiDimensionTensor() throws IOException {
- - int numberLabel = 2;
- - int numDim = 3;
- - float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f};
- -
- - int[] shape = {1, numberLabel, numDim};
- - TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- - input.loadArray(inputArr, shape);
- -
- - List<String> labels = Arrays.asList("pos", "neg");
- - LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(1, labels).build();
- -
- - TensorLabel output = op.apply(input);
- - Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
- -
- - assertThat(map).containsKey("pos");
- - float[] array = map.get("pos").getFloatArray();
- - assertThat(array).hasLength(numDim);
- - assertThat(array).isEqualTo(new float[] {0.5f, 0.1f, 0.3f});
- -
- - assertThat(map).containsKey("neg");
- - array = map.get("neg").getFloatArray();
- - assertThat(array).hasLength(numDim);
- - assertThat(array).isEqualTo(new float[] {0.2f, 0.2f, 0.1f});
- - }
- -
- - @Test
- - public void testAddAxisLabelByFilePath() throws IOException {
- - int numberLabel = 5;
- - int[] inputArr = new int[numberLabel];
- - for (int i = 0; i < numberLabel; i++) {
- - inputArr[i] = i;
- + List<String> labels = Arrays.asList("pos", "neg");
- + LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(1, labels).build();
- +
- + TensorLabel output = op.apply(input);
- + Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
- +
- + assertThat(map).containsKey("pos");
- + float[] array = map.get("pos").getFloatArray();
- + assertThat(array).hasLength(numDim);
- + assertThat(array).isEqualTo(new float[] {0.5f, 0.1f, 0.3f});
- +
- + assertThat(map).containsKey("neg");
- + array = map.get("neg").getFloatArray();
- + assertThat(array).hasLength(numDim);
- + assertThat(array).isEqualTo(new float[] {0.2f, 0.2f, 0.1f});
- }
-
- - int[] shape = {numberLabel};
- - TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- - input.loadArray(inputArr, shape);
- + @Test
- + public void testAddAxisLabelByFilePath() throws IOException {
- + int numberLabel = 5;
- + int[] inputArr = new int[numberLabel];
- + for (int i = 0; i < numberLabel; i++) {
- + inputArr[i] = i;
- + }
- +
- + int[] shape = {numberLabel};
- + TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- + input.loadArray(inputArr, shape);
-
- - LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(context, 0, LABEL_PATH).build();
- - TensorLabel output = op.apply(input);
- - Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
- + LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(context, 0, LABEL_PATH).build();
- + TensorLabel output = op.apply(input);
- + Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
-
- - List<String> labels = FileUtil.loadLabels(context, LABEL_PATH);
- - for (int i = 0; i < numberLabel; i++) {
- - String label = labels.get(i);
- + List<String> labels = FileUtil.loadLabels(context, LABEL_PATH);
- + for (int i = 0; i < numberLabel; i++) {
- + String label = labels.get(i);
-
- - assertThat(map).containsKey(label);
- + assertThat(map).containsKey(label);
-
- - int[] array = map.get(label).getIntArray();
- - assertThat(array).hasLength(1);
- - assertThat(array[0]).isEqualTo(inputArr[i]);
- + int[] array = map.get(label).getIntArray();
- + assertThat(array).hasLength(1);
- + assertThat(array[0]).isEqualTo(inputArr[i]);
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java
- index bd59051ce4ccb..d7449187cb54c 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java
- @@ -17,6 +17,7 @@ package org.tensorflow.lite.support.model;
- import static com.google.common.truth.Truth.assertThat;
-
- import androidx.test.ext.junit.runners.AndroidJUnit4;
- +
- import org.junit.Test;
- import org.junit.runner.RunWith;
-
- @@ -27,13 +28,12 @@ import org.junit.runner.RunWith;
- */
- @RunWith(AndroidJUnit4.class)
- public final class GpuDelegateProxyInstrumentedTest {
- -
- - @Test
- - public void createGpuDelegateProxyShouldSuccess() {
- - GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance();
- -
- - assertThat(proxy).isNotNull();
- - proxy.getNativeHandle();
- - proxy.close();
- - }
- + @Test
- + public void createGpuDelegateProxyShouldSuccess() {
- + GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance();
- +
- + assertThat(proxy).isNotNull();
- + proxy.getNativeHandle();
- + proxy.close();
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java
- index c1bbcc223a895..4eb2e2920c3bc 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java
- @@ -23,11 +23,10 @@ import org.robolectric.RobolectricTestRunner;
- /** Tests of {@link org.tensorflow.lite.support.model.GpuDelegateProxy}. */
- @RunWith(RobolectricTestRunner.class)
- public final class GpuDelegateProxyTest {
- + @Test
- + public void createGpuDelegateProxyWithoutDependencyShouldReturnNull() {
- + GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance();
-
- - @Test
- - public void createGpuDelegateProxyWithoutDependencyShouldReturnNull() {
- - GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance();
- -
- - assertThat(proxy).isNull();
- - }
- + assertThat(proxy).isNull();
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java
- index 86e4f72769216..342e82b2de3bb 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java
- @@ -16,143 +16,145 @@ limitations under the License.
- package org.tensorflow.lite.support.model;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.fail;
-
- import android.content.Context;
- +
- import androidx.test.core.app.ApplicationProvider;
- -import java.io.IOException;
- -import java.nio.MappedByteBuffer;
- -import java.util.HashMap;
- -import java.util.Map;
- +
- +import org.junit.Ignore;
- import org.junit.Test;
- import org.junit.runner.RunWith;
- import org.robolectric.RobolectricTestRunner;
- import org.tensorflow.lite.support.model.Model.Device;
- import org.tensorflow.lite.support.model.Model.Options;
-
- -import org.junit.Ignore;
- +import java.io.IOException;
- +import java.nio.MappedByteBuffer;
- +import java.util.HashMap;
- +import java.util.Map;
-
- /** Tests of {@link org.tensorflow.lite.support.model.Model}. */
- @RunWith(RobolectricTestRunner.class)
- public final class ModelTest {
- + private final Context context = ApplicationProvider.getApplicationContext();
- + private static final String MODEL_PATH = "add.tflite";
- +
- + @Ignore
- + @Test
- + public void testLoadLocalModel() throws IOException {
- + MappedByteBuffer byteModel = new Model.Builder(context, MODEL_PATH).build().getData();
- + assertThat(byteModel).isNotNull();
- + }
- +
- + @Ignore
- + @Test
- + public void testBuildMultiThreadModel() throws IOException {
- + MappedByteBuffer byteModel =
- + new Model.Builder(context, MODEL_PATH).setNumThreads(4).build().getData();
- + assertThat(byteModel).isNotNull();
- + }
- +
- + @Ignore
- + @Test
- + public void buildModelWithOptionsShouldSuccess() throws IOException {
- + Options options = new Options.Builder().setNumThreads(2).setDevice(Device.NNAPI).build();
- + Model model = Model.createModel(context, MODEL_PATH, options);
- + assertThat(model.getData()).isNotNull();
- + }
-
- - private final Context context = ApplicationProvider.getApplicationContext();
- - private static final String MODEL_PATH = "add.tflite";
- -
- - @Ignore
- - @Test
- - public void testLoadLocalModel() throws IOException {
- - MappedByteBuffer byteModel = new Model.Builder(context, MODEL_PATH).build().getData();
- - assertThat(byteModel).isNotNull();
- - }
- -
- - @Ignore
- - @Test
- - public void testBuildMultiThreadModel() throws IOException {
- - MappedByteBuffer byteModel =
- - new Model.Builder(context, MODEL_PATH).setNumThreads(4).build().getData();
- - assertThat(byteModel).isNotNull();
- - }
- -
- - @Ignore
- - @Test
- - public void buildModelWithOptionsShouldSuccess() throws IOException {
- - Options options = new Options.Builder().setNumThreads(2).setDevice(Device.NNAPI).build();
- - Model model = Model.createModel(context, MODEL_PATH, options);
- - assertThat(model.getData()).isNotNull();
- - }
- -
- - @Ignore
- - @Test
- - public void testGetModelPath() throws IOException {
- - String modelPath = new Model.Builder(context, MODEL_PATH).build().getPath();
- - assertThat(modelPath).isEqualTo(MODEL_PATH);
- - }
- -
- - @Test
- - public void testNonExistingLocalModel() {
- - try {
- - new Model.Builder(context, "non_exist_model_file").build();
- - fail();
- - } catch (IOException e) {
- - assertThat(e).hasMessageThat().contains("non_exist_model_file");
- + @Ignore
- + @Test
- + public void testGetModelPath() throws IOException {
- + String modelPath = new Model.Builder(context, MODEL_PATH).build().getPath();
- + assertThat(modelPath).isEqualTo(MODEL_PATH);
- }
- - }
- -
- - @Test
- - public void testNullLocalModelPath() throws IOException {
- - try {
- - new Model.Builder(context, null).build();
- - fail();
- - } catch (NullPointerException e) {
- - assertThat(e).hasMessageThat().contains("File path cannot be null.");
- +
- + @Test
- + public void testNonExistingLocalModel() {
- + try {
- + new Model.Builder(context, "non_exist_model_file").build();
- + fail();
- + } catch (IOException e) {
- + assertThat(e).hasMessageThat().contains("non_exist_model_file");
- + }
- }
- - }
- -
- - @Test
- - public void testNullContext() throws IOException {
- - try {
- - new Model.Builder(null, MODEL_PATH).build();
- - fail();
- - } catch (NullPointerException e) {
- - assertThat(e).hasMessageThat().contains("Context should not be null.");
- +
- + @Test
- + public void testNullLocalModelPath() throws IOException {
- + try {
- + new Model.Builder(context, null).build();
- + fail();
- + } catch (NullPointerException e) {
- + assertThat(e).hasMessageThat().contains("File path cannot be null.");
- + }
- + }
- +
- + @Test
- + public void testNullContext() throws IOException {
- + try {
- + new Model.Builder(null, MODEL_PATH).build();
- + fail();
- + } catch (NullPointerException e) {
- + assertThat(e).hasMessageThat().contains("Context should not be null.");
- + }
- + }
- +
- + @Ignore
- + @Test
- + public void testGetInputTensor() throws IOException {
- + Options options = new Options.Builder().build();
- + Model model = Model.createModel(context, MODEL_PATH, options);
- + assertThat(model.getInputTensor(0)).isNotNull();
- + }
- +
- + @Ignore
- + @Test
- + public void testGetOutputTensor() throws IOException {
- + Options options = new Options.Builder().build();
- + Model model = Model.createModel(context, MODEL_PATH, options);
- + assertThat(model.getOutputTensor(0)).isNotNull();
- + }
- +
- + @Ignore
- + @Test
- + public void testRun() throws IOException {
- + Context context = ApplicationProvider.getApplicationContext();
- + Model model = new Model.Builder(context, MODEL_PATH).build();
- + runModel(model);
- + }
- +
- + @Ignore
- + @Test
- + public void testMultiThreadingRun() throws IOException {
- + Context context = ApplicationProvider.getApplicationContext();
- + Model model = new Model.Builder(context, MODEL_PATH).setNumThreads(4).build();
- + runModel(model);
- + }
- +
- + @Ignore
- + @Test
- + public void testNnApiRun() throws IOException {
- + Context context = ApplicationProvider.getApplicationContext();
- + Model model = new Model.Builder(context, MODEL_PATH).setDevice(Device.NNAPI).build();
- + runModel(model);
- + }
- +
- + private static void runModel(Model model) throws IOException {
- + // Creates the inputs.
- + float[] x = {1.5f};
- + float[] y = {0.5f};
- + float[] expectedSum = {2.0f};
- + Object[] inputs = {x, y};
- +
- + // Creates the outputs buffer.
- + float[] sum = new float[1];
- + Map<Integer, Object> outputs = new HashMap<>();
- + outputs.put(0, sum);
- +
- + // Runs inference.
- + model.run(inputs, outputs);
- + assertThat(sum).isEqualTo(expectedSum);
- }
- - }
- -
- - @Ignore
- - @Test
- - public void testGetInputTensor() throws IOException {
- - Options options = new Options.Builder().build();
- - Model model = Model.createModel(context, MODEL_PATH, options);
- - assertThat(model.getInputTensor(0)).isNotNull();
- - }
- -
- - @Ignore
- - @Test
- - public void testGetOutputTensor() throws IOException {
- - Options options = new Options.Builder().build();
- - Model model = Model.createModel(context, MODEL_PATH, options);
- - assertThat(model.getOutputTensor(0)).isNotNull();
- - }
- -
- - @Ignore
- - @Test
- - public void testRun() throws IOException {
- - Context context = ApplicationProvider.getApplicationContext();
- - Model model = new Model.Builder(context, MODEL_PATH).build();
- - runModel(model);
- - }
- -
- - @Ignore
- - @Test
- - public void testMultiThreadingRun() throws IOException {
- - Context context = ApplicationProvider.getApplicationContext();
- - Model model = new Model.Builder(context, MODEL_PATH).setNumThreads(4).build();
- - runModel(model);
- - }
- -
- - @Ignore
- - @Test
- - public void testNnApiRun() throws IOException {
- - Context context = ApplicationProvider.getApplicationContext();
- - Model model = new Model.Builder(context, MODEL_PATH).setDevice(Device.NNAPI).build();
- - runModel(model);
- - }
- -
- - private static void runModel(Model model) throws IOException {
- - // Creates the inputs.
- - float[] x = {1.5f};
- - float[] y = {0.5f};
- - float[] expectedSum = {2.0f};
- - Object[] inputs = {x, y};
- -
- - // Creates the outputs buffer.
- - float[] sum = new float[1];
- - Map<Integer, Object> outputs = new HashMap<>();
- - outputs.put(0, sum);
- -
- - // Runs inference.
- - model.run(inputs, outputs);
- - assertThat(sum).isEqualTo(expectedSum);
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java
- index 3a4d09d8e5701..82b59b36155f3 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java
- @@ -26,51 +26,51 @@ import org.tensorflow.lite.DataType;
- /** Tests of {@link org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat}. */
- @RunWith(RobolectricTestRunner.class)
- public final class TensorBufferFloatTest {
- - @Test
- - public void testCreateDynamic() {
- - TensorBufferFloat tensorBufferFloat = new TensorBufferFloat();
- - assertThat(tensorBufferFloat).isNotNull();
- - }
- + @Test
- + public void testCreateDynamic() {
- + TensorBufferFloat tensorBufferFloat = new TensorBufferFloat();
- + assertThat(tensorBufferFloat).isNotNull();
- + }
-
- - @Test
- - public void testCreateFixedSize() {
- - int[] shape = new int[] {1, 2, 3};
- - TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
- - assertThat(tensorBufferFloat).isNotNull();
- - assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6);
- - }
- + @Test
- + public void testCreateFixedSize() {
- + int[] shape = new int[] {1, 2, 3};
- + TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
- + assertThat(tensorBufferFloat).isNotNull();
- + assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6);
- + }
-
- - @Test
- - public void testCreateFixedSizeWithScalarShape() {
- - int[] shape = new int[] {};
- - TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
- - assertThat(tensorBufferFloat).isNotNull();
- - assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(1);
- - }
- + @Test
- + public void testCreateFixedSizeWithScalarShape() {
- + int[] shape = new int[] {};
- + TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
- + assertThat(tensorBufferFloat).isNotNull();
- + assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(1);
- + }
-
- - @Test
- - public void testCreateWithNullShape() {
- - int[] shape = null;
- - Assert.assertThrows(NullPointerException.class, () -> new TensorBufferFloat(shape));
- - }
- + @Test
- + public void testCreateWithNullShape() {
- + int[] shape = null;
- + Assert.assertThrows(NullPointerException.class, () -> new TensorBufferFloat(shape));
- + }
-
- - @Test
- - public void testCreateWithInvalidShape() {
- - int[] shape = new int[] {1, -1, 2};
- - Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferFloat(shape));
- - }
- + @Test
- + public void testCreateWithInvalidShape() {
- + int[] shape = new int[] {1, -1, 2};
- + Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferFloat(shape));
- + }
-
- - @Test
- - public void testCreateUsingShapeWithZero() {
- - int[] shape = new int[] {1, 0, 2};
- - TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
- - assertThat(tensorBufferFloat).isNotNull();
- - assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(0);
- - }
- + @Test
- + public void testCreateUsingShapeWithZero() {
- + int[] shape = new int[] {1, 0, 2};
- + TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
- + assertThat(tensorBufferFloat).isNotNull();
- + assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(0);
- + }
-
- - @Test
- - public void testGetDataType() {
- - TensorBufferFloat tensorBufferFloat = new TensorBufferFloat();
- - assertThat(tensorBufferFloat.getDataType()).isEqualTo(DataType.FLOAT32);
- - }
- + @Test
- + public void testGetDataType() {
- + TensorBufferFloat tensorBufferFloat = new TensorBufferFloat();
- + assertThat(tensorBufferFloat.getDataType()).isEqualTo(DataType.FLOAT32);
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java
- index c55affe733eac..763356f493390 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java
- @@ -16,877 +16,878 @@ limitations under the License.
- package org.tensorflow.lite.support.tensorbuffer;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertThrows;
-
- -import java.io.IOException;
- -import java.nio.ByteBuffer;
- -import java.nio.FloatBuffer;
- -import java.util.ArrayList;
- import org.junit.Assert;
- import org.junit.Test;
- import org.junit.runner.RunWith;
- import org.robolectric.RobolectricTestRunner;
- import org.tensorflow.lite.DataType;
-
- +import java.io.IOException;
- +import java.nio.ByteBuffer;
- +import java.nio.FloatBuffer;
- +import java.util.ArrayList;
- +
- /** Test helper class for inserting and retrieving arrays. */
- class ArrayTestRunner {
- - // List of TensorBuffer types to be tested.
- - private static final DataType[] BUFFER_TYPE_LIST = {DataType.FLOAT32, DataType.UINT8};
- - // List of source arrays to be loaded into TensorBuffer during the tests.
- - private final ArrayList<Object> srcArrays;
- - // List of array data type with respect to srcArrays.
- - private final ArrayList<DataType> arrDataTypes;
- - // List of array shape with respect to srcArrays.
- - private final ArrayList<int[]> arrShapes;
- - private final int[] tensorBufferShape;
- - private final ExpectedResults expectedResForFloatBuf;
- - private final ExpectedResults expectedResForByteBuf;
- -
- - public ArrayTestRunner(Builder builder) {
- - if (builder.srcArrays.size() != builder.arrDataTypes.size()) {
- - throw new IllegalArgumentException(
- - "Number of source arrays and number of data types do not match.");
- - }
- -
- - this.srcArrays = builder.srcArrays;
- - this.arrDataTypes = builder.arrDataTypes;
- - this.arrShapes = builder.arrShapes;
- - this.tensorBufferShape = builder.tensorBufferShape;
- - this.expectedResForFloatBuf = builder.expectedResForFloatBuf;
- - this.expectedResForByteBuf = builder.expectedResForByteBuf;
- - }
- -
- - static class ExpectedResults {
- - public float[] floatArr;
- - public int[] intArr;
- - public int[] shape;
- - }
- -
- - public static class Builder {
- - private final ArrayList<Object> srcArrays = new ArrayList<>();
- - private final ArrayList<DataType> arrDataTypes = new ArrayList<>();
- - private final ArrayList<int[]> arrShapes = new ArrayList<>();
- - private int[] tensorBufferShape;
- - private final ExpectedResults expectedResForFloatBuf = new ExpectedResults();
- - private final ExpectedResults expectedResForByteBuf = new ExpectedResults();
- -
- - public static Builder newInstance() {
- - return new Builder();
- - }
- -
- - private Builder() {}
- -
- - /** Loads a test array into the test runner. */
- - public Builder addSrcArray(Object src, int[] shape) {
- - // src should be a primitive 1D array.
- - DataType dataType = dataTypeOfArray(src);
- - switch (dataType) {
- - case INT32:
- - case FLOAT32:
- - srcArrays.add(src);
- - arrDataTypes.add(dataType);
- - arrShapes.add(shape);
- - return this;
- - default:
- - throw new AssertionError("Cannot resolve srouce arrays in the DataType of " + dataType);
- - }
- - }
- -
- - public Builder setTensorBufferShape(int[] tensorBufferShape) {
- - this.tensorBufferShape = tensorBufferShape;
- - return this;
- - }
- -
- - public Builder setExpectedResults(
- - DataType bufferType, float[] expectedFloatArr, int[] expectedIntArr) {
- - ExpectedResults er;
- - switch (bufferType) {
- - case UINT8:
- - er = expectedResForByteBuf;
- - break;
- - case FLOAT32:
- - er = expectedResForFloatBuf;
- - break;
- - default:
- - throw new AssertionError("Cannot test TensorBuffer in the DataType of " + bufferType);
- - }
- -
- - er.floatArr = expectedFloatArr;
- - er.intArr = expectedIntArr;
- - return this;
- - }
- -
- - public ArrayTestRunner build() {
- - int[] expectedShape;
- - if (arrShapes.isEmpty()) {
- - // If no array will be loaded, the array is an empty array.
- - expectedShape = new int[] {0};
- - } else {
- - expectedShape = arrShapes.get(arrShapes.size() - 1);
- - }
- - expectedResForByteBuf.shape = expectedShape;
- - expectedResForFloatBuf.shape = expectedShape;
- - return new ArrayTestRunner(this);
- - }
- - }
- -
- - public static DataType[] getBufferTypeList() {
- - return BUFFER_TYPE_LIST;
- - }
- -
- - /**
- - * Runs tests in the following steps: 1. Create a TensorBuffer. If tensorBufferShape is null,
- - * create a dynamic buffer. Otherwise, create a fixed-size buffer accordingly. 2. Load arrays in
- - * srcArrays one by one into the TensotBuffer. 3. Get arrays for each supported primitive types in
- - * TensorBuffer, such as int array and float array for now. Check if the results are correct. 4.
- - * Repeat Step 1 to 3 for all buffer types in BUFFER_TYPE_LIST.
- - */
- - public void run() {
- - for (DataType bufferDataType : BUFFER_TYPE_LIST) {
- - TensorBuffer tensorBuffer;
- - if (tensorBufferShape == null) {
- - tensorBuffer = TensorBuffer.createDynamic(bufferDataType);
- - } else {
- - tensorBuffer = TensorBuffer.createFixedSize(tensorBufferShape, bufferDataType);
- - }
- - for (int i = 0; i < srcArrays.size(); i++) {
- - switch (arrDataTypes.get(i)) {
- - case INT32:
- - int[] arrInt = (int[]) srcArrays.get(i);
- - tensorBuffer.loadArray(arrInt, arrShapes.get(i));
- - break;
- - case FLOAT32:
- - float[] arrFloat = (float[]) srcArrays.get(i);
- - tensorBuffer.loadArray(arrFloat, arrShapes.get(i));
- - break;
- - default:
- - break;
- + // List of TensorBuffer types to be tested.
- + private static final DataType[] BUFFER_TYPE_LIST = {DataType.FLOAT32, DataType.UINT8};
- + // List of source arrays to be loaded into TensorBuffer during the tests.
- + private final ArrayList<Object> srcArrays;
- + // List of array data type with respect to srcArrays.
- + private final ArrayList<DataType> arrDataTypes;
- + // List of array shape with respect to srcArrays.
- + private final ArrayList<int[]> arrShapes;
- + private final int[] tensorBufferShape;
- + private final ExpectedResults expectedResForFloatBuf;
- + private final ExpectedResults expectedResForByteBuf;
- +
- + public ArrayTestRunner(Builder builder) {
- + if (builder.srcArrays.size() != builder.arrDataTypes.size()) {
- + throw new IllegalArgumentException(
- + "Number of source arrays and number of data types do not match.");
- }
- - }
- - checkResults(tensorBuffer);
- - }
- - }
- -
- - private void checkResults(TensorBuffer tensorBuffer) {
- - ExpectedResults er;
- - switch (tensorBuffer.getDataType()) {
- - case UINT8:
- - er = expectedResForByteBuf;
- - break;
- - case FLOAT32:
- - er = expectedResForFloatBuf;
- - break;
- - default:
- - throw new AssertionError(
- - "Cannot test TensorBuffer in the DataType of " + tensorBuffer.getDataType());
- - }
- -
- - // Checks getIntArray() and getFloatArray().
- - int[] resIntArr = tensorBuffer.getIntArray();
- - assertThat(resIntArr).isEqualTo(er.intArr);
- - float[] resFloatArr = tensorBuffer.getFloatArray();
- - assertThat(resFloatArr).isEqualTo(er.floatArr);
- - assertThat(tensorBuffer.getShape()).isEqualTo(er.shape);
- -
- - // Checks getIntValue(int index) and getFloatValue(int index).
- - int flatSize = tensorBuffer.getFlatSize();
- - float[] resFloatValues = new float[flatSize];
- - int[] resIntValues = new int[flatSize];
- - for (int i = 0; i < flatSize; i++) {
- - resFloatValues[i] = tensorBuffer.getFloatValue(i);
- - resIntValues[i] = tensorBuffer.getIntValue(i);
- - }
- - assertThat(resFloatValues).isEqualTo(er.floatArr);
- - assertThat(resIntValues).isEqualTo(er.intArr);
- - }
- -
- - /** Gets the data type of an 1D array. */
- - private static DataType dataTypeOfArray(Object arr) {
- - if (arr != null) {
- - Class<?> c = arr.getClass();
- - if (c.isArray()) {
- - c = c.getComponentType();
- - if (float.class.equals(c)) {
- - return DataType.FLOAT32;
- - } else if (int.class.equals(c)) {
- - return DataType.INT32;
- - } else if (byte.class.equals(c)) {
- - return DataType.UINT8;
- - } else if (long.class.equals(c)) {
- - return DataType.INT64;
- - } else if (String.class.equals(c)) {
- - return DataType.STRING;
- +
- + this.srcArrays = builder.srcArrays;
- + this.arrDataTypes = builder.arrDataTypes;
- + this.arrShapes = builder.arrShapes;
- + this.tensorBufferShape = builder.tensorBufferShape;
- + this.expectedResForFloatBuf = builder.expectedResForFloatBuf;
- + this.expectedResForByteBuf = builder.expectedResForByteBuf;
- + }
- +
- + static class ExpectedResults {
- + public float[] floatArr;
- + public int[] intArr;
- + public int[] shape;
- + }
- +
- + public static class Builder {
- + private final ArrayList<Object> srcArrays = new ArrayList<>();
- + private final ArrayList<DataType> arrDataTypes = new ArrayList<>();
- + private final ArrayList<int[]> arrShapes = new ArrayList<>();
- + private int[] tensorBufferShape;
- + private final ExpectedResults expectedResForFloatBuf = new ExpectedResults();
- + private final ExpectedResults expectedResForByteBuf = new ExpectedResults();
- +
- + public static Builder newInstance() {
- + return new Builder();
- + }
- +
- + private Builder() {}
- +
- + /** Loads a test array into the test runner. */
- + public Builder addSrcArray(Object src, int[] shape) {
- + // src should be a primitive 1D array.
- + DataType dataType = dataTypeOfArray(src);
- + switch (dataType) {
- + case INT32:
- + case FLOAT32:
- + srcArrays.add(src);
- + arrDataTypes.add(dataType);
- + arrShapes.add(shape);
- + return this;
- + default:
- + throw new AssertionError(
- + "Cannot resolve srouce arrays in the DataType of " + dataType);
- + }
- + }
- +
- + public Builder setTensorBufferShape(int[] tensorBufferShape) {
- + this.tensorBufferShape = tensorBufferShape;
- + return this;
- }
- - }
- +
- + public Builder setExpectedResults(
- + DataType bufferType, float[] expectedFloatArr, int[] expectedIntArr) {
- + ExpectedResults er;
- + switch (bufferType) {
- + case UINT8:
- + er = expectedResForByteBuf;
- + break;
- + case FLOAT32:
- + er = expectedResForFloatBuf;
- + break;
- + default:
- + throw new AssertionError(
- + "Cannot test TensorBuffer in the DataType of " + bufferType);
- + }
- +
- + er.floatArr = expectedFloatArr;
- + er.intArr = expectedIntArr;
- + return this;
- + }
- +
- + public ArrayTestRunner build() {
- + int[] expectedShape;
- + if (arrShapes.isEmpty()) {
- + // If no array will be loaded, the array is an empty array.
- + expectedShape = new int[] {0};
- + } else {
- + expectedShape = arrShapes.get(arrShapes.size() - 1);
- + }
- + expectedResForByteBuf.shape = expectedShape;
- + expectedResForFloatBuf.shape = expectedShape;
- + return new ArrayTestRunner(this);
- + }
- + }
- +
- + public static DataType[] getBufferTypeList() {
- + return BUFFER_TYPE_LIST;
- + }
- +
- + /**
- + * Runs tests in the following steps: 1. Create a TensorBuffer. If tensorBufferShape is null,
- + * create a dynamic buffer. Otherwise, create a fixed-size buffer accordingly. 2. Load arrays in
- + * srcArrays one by one into the TensotBuffer. 3. Get arrays for each supported primitive types
- + * in TensorBuffer, such as int array and float array for now. Check if the results are
- + * correct. 4. Repeat Step 1 to 3 for all buffer types in BUFFER_TYPE_LIST.
- + */
- + public void run() {
- + for (DataType bufferDataType : BUFFER_TYPE_LIST) {
- + TensorBuffer tensorBuffer;
- + if (tensorBufferShape == null) {
- + tensorBuffer = TensorBuffer.createDynamic(bufferDataType);
- + } else {
- + tensorBuffer = TensorBuffer.createFixedSize(tensorBufferShape, bufferDataType);
- + }
- + for (int i = 0; i < srcArrays.size(); i++) {
- + switch (arrDataTypes.get(i)) {
- + case INT32:
- + int[] arrInt = (int[]) srcArrays.get(i);
- + tensorBuffer.loadArray(arrInt, arrShapes.get(i));
- + break;
- + case FLOAT32:
- + float[] arrFloat = (float[]) srcArrays.get(i);
- + tensorBuffer.loadArray(arrFloat, arrShapes.get(i));
- + break;
- + default:
- + break;
- + }
- + }
- + checkResults(tensorBuffer);
- + }
- + }
- +
- + private void checkResults(TensorBuffer tensorBuffer) {
- + ExpectedResults er;
- + switch (tensorBuffer.getDataType()) {
- + case UINT8:
- + er = expectedResForByteBuf;
- + break;
- + case FLOAT32:
- + er = expectedResForFloatBuf;
- + break;
- + default:
- + throw new AssertionError("Cannot test TensorBuffer in the DataType of "
- + + tensorBuffer.getDataType());
- + }
- +
- + // Checks getIntArray() and getFloatArray().
- + int[] resIntArr = tensorBuffer.getIntArray();
- + assertThat(resIntArr).isEqualTo(er.intArr);
- + float[] resFloatArr = tensorBuffer.getFloatArray();
- + assertThat(resFloatArr).isEqualTo(er.floatArr);
- + assertThat(tensorBuffer.getShape()).isEqualTo(er.shape);
- +
- + // Checks getIntValue(int index) and getFloatValue(int index).
- + int flatSize = tensorBuffer.getFlatSize();
- + float[] resFloatValues = new float[flatSize];
- + int[] resIntValues = new int[flatSize];
- + for (int i = 0; i < flatSize; i++) {
- + resFloatValues[i] = tensorBuffer.getFloatValue(i);
- + resIntValues[i] = tensorBuffer.getIntValue(i);
- + }
- + assertThat(resFloatValues).isEqualTo(er.floatArr);
- + assertThat(resIntValues).isEqualTo(er.intArr);
- + }
- +
- + /** Gets the data type of an 1D array. */
- + private static DataType dataTypeOfArray(Object arr) {
- + if (arr != null) {
- + Class<?> c = arr.getClass();
- + if (c.isArray()) {
- + c = c.getComponentType();
- + if (float.class.equals(c)) {
- + return DataType.FLOAT32;
- + } else if (int.class.equals(c)) {
- + return DataType.INT32;
- + } else if (byte.class.equals(c)) {
- + return DataType.UINT8;
- + } else if (long.class.equals(c)) {
- + return DataType.INT64;
- + } else if (String.class.equals(c)) {
- + return DataType.STRING;
- + }
- + }
- + }
- + throw new IllegalArgumentException(
- + "Requires a 1D array. Cannot resolve data type of " + arr.getClass().getName());
- }
- - throw new IllegalArgumentException(
- - "Requires a 1D array. Cannot resolve data type of " + arr.getClass().getName());
- - }
- }
-
- /** Tests of {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}. */
- @RunWith(RobolectricTestRunner.class)
- public final class TensorBufferTest {
- - // FLOAT_ARRAY1 and INT_ARRAY1 correspond to each other.
- - private static final int[] ARRAY1_SHAPE = new int[] {2, 3};
- - private static final float[] FLOAT_ARRAY1 = new float[] {500.1f, 4.2f, 3.3f, 2.4f, 1.5f, 6.1f};
- - private static final float[] FLOAT_ARRAY1_ROUNDED =
- - new float[] {500.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f};
- - // FLOAT_ARRAY1_CAPPED and INT_ARRAY1_CAPPED correspond to the expected values when converted into
- - // uint8.
- - private static final float[] FLOAT_ARRAY1_CAPPED =
- - new float[] {255.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f};
- - private static final int[] INT_ARRAY1 = new int[] {500, 4, 3, 2, 1, 6};
- - private static final int[] INT_ARRAY1_CAPPED = new int[] {255, 4, 3, 2, 1, 6};
- - // FLOAT_ARRAY2 and INT_ARRAY2 correspond to each other.
- - private static final int[] ARRAY2_SHAPE = new int[] {2, 1};
- - private static final float[] FLOAT_ARRAY2 = new float[] {6.7f, 7.6f};
- - private static final float[] FLOAT_ARRAY2_ROUNDED = new float[] {6.0f, 7.0f};
- - private static final int[] INT_ARRAY2 = new int[] {6, 7};
- - // FLOAT_ARRAY2 and FLOAT_ARRAY3 have the same size.
- - private static final int[] ARRAY3_SHAPE = new int[] {2, 1};
- - private static final float[] FLOAT_ARRAY3 = new float[] {8.2f, 9.9f};
- - private static final float[] FLOAT_ARRAY3_ROUNDED = new float[] {8.0f, 9.0f};
- - // INT_ARRAY2 and INT_ARRAY3 have the same size.
- - private static final int[] INT_ARRAY3 = new int[] {8, 9};
- - private static final int[] EMPTY_ARRAY_SHAPE = new int[] {0};
- - private static final int[] EMPTY_INT_ARRAY = new int[0];
- - private static final float[] EMPTY_FLOAT_ARRAY = new float[0];
- - // Single element array which represents a scalar.
- - private static final int[] SCALAR_ARRAY_SHAPE = new int[] {};
- - private static final float[] FLOAT_SCALAR_ARRAY = new float[] {800.2f};
- - private static final float[] FLOAT_SCALAR_ARRAY_ROUNDED = new float[] {800.0f};
- - private static final float[] FLOAT_SCALAR_ARRAY_CAPPED = new float[] {255.0f};
- - private static final int[] INT_SCALAR_ARRAY = new int[] {800};
- - private static final int[] INT_SCALAR_ARRAY_CAPPED = new int[] {255};
- - // Several different ByteBuffer.
- - private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.allocateDirect(0);
- - private static final ByteBuffer FLOAT_BYTE_BUFFER1 = ByteBuffer.allocateDirect(24);
- -
- - static {
- - FLOAT_BYTE_BUFFER1.rewind();
- -
- - FloatBuffer floatBuffer = FLOAT_BYTE_BUFFER1.asFloatBuffer();
- - floatBuffer.put(FLOAT_ARRAY1);
- - }
- -
- - private static final ByteBuffer INT_BYTE_BUFFER2 = ByteBuffer.allocateDirect(2);
- -
- - static {
- - INT_BYTE_BUFFER2.rewind();
- -
- - for (int a : INT_ARRAY2) {
- - INT_BYTE_BUFFER2.put((byte) a);
- - }
- - }
- -
- - @Test
- - public void testCreateFixedSizeTensorBufferFloat() {
- - int[] shape = new int[] {1, 2, 3};
- - TensorBuffer tensorBufferFloat = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- - assertThat(tensorBufferFloat).isNotNull();
- - assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6);
- - }
- -
- - @Test
- - public void testCreateFixedSizeTensorBufferUint8() {
- - int[] shape = new int[] {1, 2, 3};
- - TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- - assertThat(tensorBufferUint8).isNotNull();
- - assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6);
- - }
- -
- - @Test
- - public void testCreateDynamicTensorBufferFloat() {
- - TensorBuffer tensorBufferFloat = TensorBuffer.createDynamic(DataType.FLOAT32);
- - assertThat(tensorBufferFloat).isNotNull();
- - }
- -
- - @Test
- - public void testCreateDynamicTensorBufferUint8() {
- - TensorBuffer tensorBufferUint8 = TensorBuffer.createDynamic(DataType.UINT8);
- - assertThat(tensorBufferUint8).isNotNull();
- - }
- -
- - @Test
- - public void testCreateTensorBufferFromFixedSize() {
- - int[] shape = new int[] {1, 2, 3};
- - TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- - TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
- - assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3});
- - }
- -
- - @Test
- - public void testCreateTensorBufferFromDynamicSize() {
- - int[] shape = new int[] {1, 2, 3};
- - TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8);
- - src.resize(shape);
- - TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
- - assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3});
- - }
- -
- - @Test
- - public void testCreateTensorBufferUInt8FromUInt8() {
- - int[] shape = new int[] {INT_ARRAY1.length};
- - TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- - src.loadArray(INT_ARRAY1);
- - TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8);
- - int[] data = dst.getIntArray();
- - assertThat(data).isEqualTo(INT_ARRAY1_CAPPED);
- - }
- -
- - @Test
- - public void testCreateTensorBufferUInt8FromFloat32() {
- - TensorBuffer src = TensorBuffer.createDynamic(DataType.FLOAT32);
- - src.loadArray(FLOAT_ARRAY1, ARRAY1_SHAPE);
- - TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8);
- - int[] data = dst.getIntArray();
- - assertThat(data).isEqualTo(INT_ARRAY1_CAPPED);
- - }
- -
- - @Test
- - public void testCreateTensorBufferFloat32FromUInt8() {
- - TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8);
- - src.loadArray(INT_ARRAY1, ARRAY1_SHAPE);
- - TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
- - float[] data = dst.getFloatArray();
- - assertThat(data).isEqualTo(FLOAT_ARRAY1_CAPPED);
- - }
- -
- - @Test
- - public void testCreateTensorBufferFloat32FromFloat32() {
- - int[] shape = new int[] {FLOAT_ARRAY1.length};
- - TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- - src.loadArray(FLOAT_ARRAY1);
- - TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
- - float[] data = dst.getFloatArray();
- - assertThat(data).isEqualTo(FLOAT_ARRAY1);
- - }
- -
- - @Test
- - public void testGetBuffer() throws IOException {
- - int[] shape = new int[] {1, 2, 3};
- - TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- - assertThat(tensorBufferUint8.getBuffer()).isNotNull();
- - }
- -
- - @Test
- - public void testLoadAndGetIntArrayWithFixedSizeForScalarArray() throws IOException {
- - ArrayTestRunner.Builder.newInstance()
- - .addSrcArray(INT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE)
- - .setTensorBufferShape(SCALAR_ARRAY_SHAPE)
- - .setExpectedResults(
- - /*bufferType = */ DataType.FLOAT32,
- - /*expectedFloatArr=*/ FLOAT_SCALAR_ARRAY_ROUNDED,
- - /*expectedIntArr=*/ INT_SCALAR_ARRAY)
- - .setExpectedResults(
- - /*bufferType = */ DataType.UINT8,
- - /*expectedFloatArr=*/ FLOAT_SCALAR_ARRAY_CAPPED,
- - /*expectedIntArr=*/ INT_SCALAR_ARRAY_CAPPED)
- - .build()
- - .run();
- - }
- -
- - @Test
- - public void testLoadAndGetFloatArrayWithFixedSizeForScalarArray() throws IOException {
- - ArrayTestRunner.Builder.newInstance()
- - .addSrcArray(FLOAT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE)
- - .setTensorBufferShape(SCALAR_ARRAY_SHAPE)
- - .setExpectedResults(
- - /*bufferType = */ DataType.FLOAT32,
- - /*expectedFloatArr=*/ FLOAT_SCALAR_ARRAY,
- - /*expectedIntArr=*/ INT_SCALAR_ARRAY)
- - .setExpectedResults(
- - /*bufferType = */ DataType.UINT8,
- - /*expectedFloatArr=*/ FLOAT_SCALAR_ARRAY_CAPPED,
- - /*expectedIntArr=*/ INT_SCALAR_ARRAY_CAPPED)
- - .build()
- - .run();
- - }
- -
- - @Test
- - public void testLoadAndGetIntArrayWithFixedSize() {
- - ArrayTestRunner.Builder.newInstance()
- - .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
- - .setTensorBufferShape(ARRAY1_SHAPE)
- - .setExpectedResults(
- - /*bufferType = */ DataType.FLOAT32,
- - /*expectedFloatArr=*/ FLOAT_ARRAY1_ROUNDED,
- - /*expectedIntArr=*/ INT_ARRAY1)
- - .setExpectedResults(
- - /*bufferType = */ DataType.UINT8,
- - /*expectedFloatArr=*/ FLOAT_ARRAY1_CAPPED,
- - /*expectedIntArr=*/ INT_ARRAY1_CAPPED)
- - .build()
- - .run();
- - }
- -
- - @Test
- - public void testLoadAndGetFloatArrayWithFixedSize() {
- - ArrayTestRunner.Builder.newInstance()
- - .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
- - .setTensorBufferShape(ARRAY1_SHAPE)
- - .setExpectedResults(
- - /*bufferType = */ DataType.FLOAT32,
- - /*expectedFloatArr=*/ FLOAT_ARRAY1,
- - /*expectedIntArr=*/ INT_ARRAY1)
- - .setExpectedResults(
- - /*bufferType = */ DataType.UINT8,
- - /*expectedFloatArr=*/ FLOAT_ARRAY1_CAPPED,
- - /*expectedIntArr=*/ INT_ARRAY1_CAPPED)
- - .build()
- - .run();
- - }
- -
- - @Test
- - public void testRepeatedLoadAndGetIntArrayWithSameFixedSize() {
- - ArrayTestRunner.Builder.newInstance()
- - .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE)
- - .addSrcArray(INT_ARRAY3, ARRAY3_SHAPE)
- - .setTensorBufferShape(ARRAY2_SHAPE)
- - .setExpectedResults(
- - /*bufferType = */ DataType.FLOAT32,
- - /*expectedFloatArr=*/ FLOAT_ARRAY3_ROUNDED,
- - /*expectedIntArr=*/ INT_ARRAY3)
- - .setExpectedResults(
- - /*bufferType = */ DataType.UINT8,
- - /*expectedFloatArr=*/ FLOAT_ARRAY3_ROUNDED,
- - /*expectedIntArr=*/ INT_ARRAY3)
- - .build()
- - .run();
- - }
- -
- - @Test
- - public void testRepeatedLoadAndGetFloatArrayWithSameFixedSize() {
- - ArrayTestRunner.Builder.newInstance()
- - .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
- - .addSrcArray(FLOAT_ARRAY3, ARRAY3_SHAPE)
- - .setTensorBufferShape(ARRAY2_SHAPE)
- - .setExpectedResults(
- - /*bufferType = */ DataType.FLOAT32,
- - /*expectedFloatArr=*/ FLOAT_ARRAY3,
- - /*expectedIntArr=*/ INT_ARRAY3)
- - .setExpectedResults(
- - /*bufferType = */ DataType.UINT8,
- - /*expectedFloatArr=*/ FLOAT_ARRAY3_ROUNDED,
- - /*expectedIntArr=*/ INT_ARRAY3)
- - .build()
- - .run();
- - }
- -
- - @Test
- - public void testRepeatedLoadIntArrayWithDifferentFixedSize() {
- - int[] srcArr1 = INT_ARRAY1;
- - int[] srcArr2 = INT_ARRAY2;
- - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- - TensorBuffer tensorBuffer =
- - TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType);
- - tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length});
- - // Load srcArr2 which had different size as srcArr1.
- - Assert.assertThrows(
- - IllegalArgumentException.class,
- - () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length}));
- - }
- - }
- -
- - @Test
- - public void testRepeatedLoadFloatArrayWithDifferentFixedSize() {
- - float[] srcArr1 = FLOAT_ARRAY1;
- - float[] srcArr2 = FLOAT_ARRAY2;
- - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- - TensorBuffer tensorBuffer =
- - TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType);
- - tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length});
- - // Load srcArr2 which had different size as srcArr1.
- - Assert.assertThrows(
- - IllegalArgumentException.class,
- - () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length}));
- - }
- - }
- -
- - @Test
- - public void testLoadAndGetIntArrayWithDynamicSize() {
- - ArrayTestRunner.Builder.newInstance()
- - .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
- - .setExpectedResults(
- - /*bufferType = */ DataType.FLOAT32,
- - /*expectedFloatArr=*/ FLOAT_ARRAY1_ROUNDED,
- - /*expectedIntArr=*/ INT_ARRAY1)
- - .setExpectedResults(
- - /*bufferType = */ DataType.UINT8,
- - /*expectedFloatArr=*/ FLOAT_ARRAY1_CAPPED,
- - /*expectedIntArr=*/ INT_ARRAY1_CAPPED)
- - .build()
- - .run();
- - }
- -
- - @Test
- - public void testLoadAndGetFloatArrayWithDynamicSize() {
- - ArrayTestRunner.Builder.newInstance()
- - .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
- - .setExpectedResults(
- - /*bufferType = */ DataType.FLOAT32,
- - /*expectedFloatArr=*/ FLOAT_ARRAY1,
- - /*expectedIntArr=*/ INT_ARRAY1)
- - .setExpectedResults(
- - /*bufferType = */ DataType.UINT8,
- - /*expectedFloatArr=*/ FLOAT_ARRAY1_CAPPED,
- - /*expectedIntArr=*/ INT_ARRAY1_CAPPED)
- - .build()
- - .run();
- - }
- -
- - @Test
- - public void testRepeatedLoadAndGetIntArrayWithDifferentDynamicSize() {
- - ArrayTestRunner.Builder.newInstance()
- - .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
- - .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE)
- - .setExpectedResults(
- - /*bufferType = */ DataType.FLOAT32,
- - /*expectedFloatArr=*/ FLOAT_ARRAY2_ROUNDED,
- - /*expectedIntArr=*/ INT_ARRAY2)
- - .setExpectedResults(
- - /*bufferType = */ DataType.UINT8,
- - /*expectedFloatArr=*/ FLOAT_ARRAY2_ROUNDED,
- - /*expectedIntArr=*/ INT_ARRAY2)
- - .build()
- - .run();
- - }
- -
- - @Test
- - public void testRepeatedLoadAndGetFloatArrayWithDifferentDynamicSize() {
- - ArrayTestRunner.Builder.newInstance()
- - .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
- - .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
- - .setExpectedResults(
- - /*bufferType = */ DataType.FLOAT32,
- - /*expectedFloatArr=*/ FLOAT_ARRAY2,
- - /*expectedIntArr=*/ INT_ARRAY2)
- - .setExpectedResults(
- - /*bufferType = */ DataType.UINT8,
- - /*expectedFloatArr=*/ FLOAT_ARRAY2_ROUNDED,
- - /*expectedIntArr=*/ INT_ARRAY2)
- - .build()
- - .run();
- - }
- -
- - @Test
- - public void testGetForEmptyArrayWithFixedSizeBuffer() {
- - ArrayTestRunner.Builder.newInstance()
- - .setTensorBufferShape(EMPTY_ARRAY_SHAPE)
- - .setExpectedResults(
- - /*bufferType = */ DataType.FLOAT32,
- - /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
- - /*expectedIntArr=*/ EMPTY_INT_ARRAY)
- - .setExpectedResults(
- - /*bufferType = */ DataType.UINT8,
- - /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
- - /*expectedIntArr=*/ EMPTY_INT_ARRAY)
- - .build()
- - .run();
- - }
- -
- - @Test
- - public void testGetForEmptyArrayWithDynamicBuffer() {
- - ArrayTestRunner.Builder.newInstance()
- - .setExpectedResults(
- - /*bufferType = */ DataType.FLOAT32,
- - /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
- - /*expectedIntArr=*/ EMPTY_INT_ARRAY)
- - .setExpectedResults(
- - /*bufferType = */ DataType.UINT8,
- - /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
- - /*expectedIntArr=*/ EMPTY_INT_ARRAY)
- - .build()
- - .run();
- - }
- -
- - @Test
- - public void testRepeatedLoadAndGetForEmptyArray() {
- - ArrayTestRunner.Builder.newInstance()
- - .addSrcArray(EMPTY_INT_ARRAY, EMPTY_ARRAY_SHAPE)
- - .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
- - .addSrcArray(EMPTY_FLOAT_ARRAY, EMPTY_ARRAY_SHAPE)
- - .setExpectedResults(
- - /*bufferType = */ DataType.FLOAT32,
- - /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
- - /*expectedIntArr=*/ EMPTY_INT_ARRAY)
- - .setExpectedResults(
- - /*bufferType = */ DataType.UINT8,
- - /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
- - /*expectedIntArr=*/ EMPTY_INT_ARRAY)
- - .build()
- - .run();
- - }
- -
- - @Test
- - public void testLoadNullIntArrays() {
- - int[] nullArray = null;
- - int[] shape = new int[] {};
- - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- - Assert.assertThrows(
- - NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape));
- - }
- - }
- -
- - @Test
- - public void testLoadNullFloatArrays() {
- - float[] nullArray = null;
- - int[] shape = new int[] {};
- - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- - Assert.assertThrows(
- - NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape));
- - }
- - }
- -
- - @Test
- - public void testLoadFloatArraysWithNullShape() {
- - float[] arr = new float[] {1.0f};
- - int[] nullShape = null;
- - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- - Assert.assertThrows(NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape));
- - }
- - }
- -
- - @Test
- - public void testLoadIntArraysWithNullShape() {
- - int[] arr = new int[] {1};
- - int[] nullShape = null;
- - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- - Assert.assertThrows(NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape));
- - }
- - }
- -
- - @Test
- - public void testLoadIntArraysWithoutShapeAndArrayDoesNotMatchShape() {
- - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- - TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType);
- - Assert.assertThrows(
- - IllegalArgumentException.class, () -> fixedTensorBuffer.loadArray(INT_ARRAY2));
- - }
- - }
- -
- - @Test
- - public void testLoadFloatArraysWithoutShapeAndArrayDoesNotMatchShape() {
- - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- - TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType);
- - Assert.assertThrows(
- - IllegalArgumentException.class, () -> fixedTensorBuffer.loadArray(FLOAT_ARRAY2));
- - }
- - }
- -
- - @Test
- - public void testLoadByteBufferForNullBuffer() {
- - ByteBuffer byteBuffer = null;
- - int[] shape = new int[] {};
- - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- - Assert.assertThrows(
- - NullPointerException.class, () -> tensorBuffer.loadBuffer(byteBuffer, shape));
- - }
- - }
- -
- - @Test
- - public void testLoadByteBufferForEmptyBuffer() {
- - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- - tensorBuffer.loadBuffer(EMPTY_BYTE_BUFFER, EMPTY_ARRAY_SHAPE);
- - assertThat(tensorBuffer.getFlatSize()).isEqualTo(0);
- - }
- - }
- -
- - @Test
- - public void testLoadByteBufferWithDifferentFixedSize() {
- - // Create a fixed-size TensorBuffer with size 2, and load a ByteBuffer with size 5.
- - int[] tensorBufferShape = new int[] {2};
- - TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(tensorBufferShape, DataType.FLOAT32);
- - Assert.assertThrows(
- - IllegalArgumentException.class,
- - () -> tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE));
- - }
- -
- - @Test
- - public void testLoadByteBufferWithMisMatchDataType() {
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- - int[] wrongShape = new int[] {1};
- - // Size of INT_BYTE_BUFFER is 8 bytes. It does not match the specified shape.
- - Assert.assertThrows(
- - IllegalArgumentException.class,
- - () -> tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, wrongShape));
- - }
- -
- - @Test
- - public void testLoadByteBufferForTensorBufferFloat() {
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- - tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE);
- - assertThat(tensorBuffer.getFloatArray()).isEqualTo(FLOAT_ARRAY1);
- - assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY1_SHAPE);
- - }
- -
- - @Test
- - public void testLoadByteBufferForTensorBufferUint8() {
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- - tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, ARRAY2_SHAPE);
- - assertThat(tensorBuffer.getIntArray()).isEqualTo(INT_ARRAY2);
- - assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY2_SHAPE);
- - }
- -
- - @Test
- - public void testGetFloatValueWithInvalidIndex() {
- - float[] arrayWithSixElements = FLOAT_ARRAY1;
- - int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE;
- - int[] invalidIndexes = {-1, 7};
- - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- - tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements);
- - for (int invalidIndex : invalidIndexes) {
- - Assert.assertThrows(
- - IndexOutOfBoundsException.class, () -> tensorBuffer.getFloatValue(invalidIndex));
- - }
- - }
- - }
- -
- - @Test
- - public void testGetFloatValueFromScalarWithInvalidIndex() {
- - int[] shape = new int[] {};
- - float[] arr = new float[] {10.0f};
- - int[] invalidIndexes =
- - new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize.
- - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- - tensorBuffer.loadArray(arr, shape);
- - for (int invalidIndex : invalidIndexes) {
- - Assert.assertThrows(
- - IndexOutOfBoundsException.class, () -> tensorBuffer.getFloatValue(invalidIndex));
- - }
- - }
- - }
- -
- - @Test
- - public void testGetIntValueWithInvalidIndex() {
- - float[] arrayWithSixElements = FLOAT_ARRAY1;
- - int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE;
- - int[] invalidIndexes = {-1, 7};
- - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- - tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements);
- - for (int invalidIndex : invalidIndexes) {
- - Assert.assertThrows(
- - IndexOutOfBoundsException.class, () -> tensorBuffer.getIntValue(invalidIndex));
- - }
- - }
- - }
- -
- - @Test
- - public void testGetIntValueFromScalarWithInvalidIndex() {
- - int[] shape = new int[] {};
- - float[] arr = new float[] {10.0f};
- - int[] invalidIndexes =
- - new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize.
- - for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- - tensorBuffer.loadArray(arr, shape);
- - for (int invalidIndex : invalidIndexes) {
- - Assert.assertThrows(
- - IndexOutOfBoundsException.class, () -> tensorBuffer.getIntValue(invalidIndex));
- - }
- - }
- - }
- -
- - @Test
- - public void testLoadByteBufferSliceForTensorBufferFloat() {
- - TensorBuffer original = TensorBuffer.createDynamic(DataType.FLOAT32);
- - original.loadArray(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, new int[] {6});
- - ByteBuffer buffer = original.getBuffer();
- - // Slice original buffer to 3 sub-buffer, each of which has 2 element
- - int numBuffers = 3;
- - int numElements = 2;
- - int subArrayLength = numElements * original.getTypeSize();
- - TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType());
- - for (int i = 0; i < numBuffers; i++) {
- - buffer.position(i * subArrayLength);
- - ByteBuffer subBuffer = buffer.slice();
- - // ByteBuffer.slice doesn't keep order.
- - subBuffer.order(buffer.order()).limit(subArrayLength);
- - tensorSlice.loadBuffer(subBuffer, new int[] {numElements});
- - float[] arraySlice = tensorSlice.getFloatArray();
- - assertThat(arraySlice.length).isEqualTo(numElements);
- - assertThat(arraySlice[0]).isEqualTo(i * numElements + 1);
- - assertThat(arraySlice[1]).isEqualTo(i * numElements + 2);
- - }
- - }
- -
- - @Test
- - public void testLoadByteBufferSliceForTensorBufferUInt8() {
- - TensorBuffer original = TensorBuffer.createDynamic(DataType.UINT8);
- - original.loadArray(new int[] {1, 2, 3, 4, 5, 6}, new int[] {6});
- - ByteBuffer buffer = original.getBuffer();
- - // Slice original buffer to 3 sub-buffer, each of which has 2 element
- - int numBuffers = 3;
- - int numElements = 2;
- - int subArrayLength = numElements * original.getTypeSize();
- - TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType());
- - for (int i = 0; i < numBuffers; i++) {
- - buffer.position(i * subArrayLength);
- - ByteBuffer subBuffer = buffer.slice();
- - // ByteBuffer.slice doesn't keep order.
- - subBuffer.order(buffer.order()).limit(subArrayLength);
- - tensorSlice.loadBuffer(subBuffer, new int[] {numElements});
- - int[] arraySlice = tensorSlice.getIntArray();
- - assertThat(arraySlice.length).isEqualTo(numElements);
- - assertThat(arraySlice[0]).isEqualTo(i * numElements + 1);
- - assertThat(arraySlice[1]).isEqualTo(i * numElements + 2);
- - }
- - }
- -
- - @Test
- - public void getShapeFailsAfterByteBufferChanged() {
- - TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32);
- - ByteBuffer byteBuffer = tensorBuffer.getBuffer();
- - byteBuffer.limit(5);
- -
- - IllegalStateException exception =
- - assertThrows(IllegalStateException.class, tensorBuffer::getShape);
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The"
- + // FLOAT_ARRAY1 and INT_ARRAY1 correspond to each other.
- + private static final int[] ARRAY1_SHAPE = new int[] {2, 3};
- + private static final float[] FLOAT_ARRAY1 = new float[] {500.1f, 4.2f, 3.3f, 2.4f, 1.5f, 6.1f};
- + private static final float[] FLOAT_ARRAY1_ROUNDED =
- + new float[] {500.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f};
- + // FLOAT_ARRAY1_CAPPED and INT_ARRAY1_CAPPED correspond to the expected values when converted
- + // into uint8.
- + private static final float[] FLOAT_ARRAY1_CAPPED =
- + new float[] {255.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f};
- + private static final int[] INT_ARRAY1 = new int[] {500, 4, 3, 2, 1, 6};
- + private static final int[] INT_ARRAY1_CAPPED = new int[] {255, 4, 3, 2, 1, 6};
- + // FLOAT_ARRAY2 and INT_ARRAY2 correspond to each other.
- + private static final int[] ARRAY2_SHAPE = new int[] {2, 1};
- + private static final float[] FLOAT_ARRAY2 = new float[] {6.7f, 7.6f};
- + private static final float[] FLOAT_ARRAY2_ROUNDED = new float[] {6.0f, 7.0f};
- + private static final int[] INT_ARRAY2 = new int[] {6, 7};
- + // FLOAT_ARRAY2 and FLOAT_ARRAY3 have the same size.
- + private static final int[] ARRAY3_SHAPE = new int[] {2, 1};
- + private static final float[] FLOAT_ARRAY3 = new float[] {8.2f, 9.9f};
- + private static final float[] FLOAT_ARRAY3_ROUNDED = new float[] {8.0f, 9.0f};
- + // INT_ARRAY2 and INT_ARRAY3 have the same size.
- + private static final int[] INT_ARRAY3 = new int[] {8, 9};
- + private static final int[] EMPTY_ARRAY_SHAPE = new int[] {0};
- + private static final int[] EMPTY_INT_ARRAY = new int[0];
- + private static final float[] EMPTY_FLOAT_ARRAY = new float[0];
- + // Single element array which represents a scalar.
- + private static final int[] SCALAR_ARRAY_SHAPE = new int[] {};
- + private static final float[] FLOAT_SCALAR_ARRAY = new float[] {800.2f};
- + private static final float[] FLOAT_SCALAR_ARRAY_ROUNDED = new float[] {800.0f};
- + private static final float[] FLOAT_SCALAR_ARRAY_CAPPED = new float[] {255.0f};
- + private static final int[] INT_SCALAR_ARRAY = new int[] {800};
- + private static final int[] INT_SCALAR_ARRAY_CAPPED = new int[] {255};
- + // Several different ByteBuffer.
- + private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.allocateDirect(0);
- + private static final ByteBuffer FLOAT_BYTE_BUFFER1 = ByteBuffer.allocateDirect(24);
- +
- + static {
- + FLOAT_BYTE_BUFFER1.rewind();
- +
- + FloatBuffer floatBuffer = FLOAT_BYTE_BUFFER1.asFloatBuffer();
- + floatBuffer.put(FLOAT_ARRAY1);
- + }
- +
- + private static final ByteBuffer INT_BYTE_BUFFER2 = ByteBuffer.allocateDirect(2);
- +
- + static {
- + INT_BYTE_BUFFER2.rewind();
- +
- + for (int a : INT_ARRAY2) {
- + INT_BYTE_BUFFER2.put((byte) a);
- + }
- + }
- +
- + @Test
- + public void testCreateFixedSizeTensorBufferFloat() {
- + int[] shape = new int[] {1, 2, 3};
- + TensorBuffer tensorBufferFloat = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- + assertThat(tensorBufferFloat).isNotNull();
- + assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6);
- + }
- +
- + @Test
- + public void testCreateFixedSizeTensorBufferUint8() {
- + int[] shape = new int[] {1, 2, 3};
- + TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- + assertThat(tensorBufferUint8).isNotNull();
- + assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6);
- + }
- +
- + @Test
- + public void testCreateDynamicTensorBufferFloat() {
- + TensorBuffer tensorBufferFloat = TensorBuffer.createDynamic(DataType.FLOAT32);
- + assertThat(tensorBufferFloat).isNotNull();
- + }
- +
- + @Test
- + public void testCreateDynamicTensorBufferUint8() {
- + TensorBuffer tensorBufferUint8 = TensorBuffer.createDynamic(DataType.UINT8);
- + assertThat(tensorBufferUint8).isNotNull();
- + }
- +
- + @Test
- + public void testCreateTensorBufferFromFixedSize() {
- + int[] shape = new int[] {1, 2, 3};
- + TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
- + assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3});
- + }
- +
- + @Test
- + public void testCreateTensorBufferFromDynamicSize() {
- + int[] shape = new int[] {1, 2, 3};
- + TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8);
- + src.resize(shape);
- + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
- + assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3});
- + }
- +
- + @Test
- + public void testCreateTensorBufferUInt8FromUInt8() {
- + int[] shape = new int[] {INT_ARRAY1.length};
- + TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- + src.loadArray(INT_ARRAY1);
- + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8);
- + int[] data = dst.getIntArray();
- + assertThat(data).isEqualTo(INT_ARRAY1_CAPPED);
- + }
- +
- + @Test
- + public void testCreateTensorBufferUInt8FromFloat32() {
- + TensorBuffer src = TensorBuffer.createDynamic(DataType.FLOAT32);
- + src.loadArray(FLOAT_ARRAY1, ARRAY1_SHAPE);
- + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8);
- + int[] data = dst.getIntArray();
- + assertThat(data).isEqualTo(INT_ARRAY1_CAPPED);
- + }
- +
- + @Test
- + public void testCreateTensorBufferFloat32FromUInt8() {
- + TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8);
- + src.loadArray(INT_ARRAY1, ARRAY1_SHAPE);
- + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
- + float[] data = dst.getFloatArray();
- + assertThat(data).isEqualTo(FLOAT_ARRAY1_CAPPED);
- + }
- +
- + @Test
- + public void testCreateTensorBufferFloat32FromFloat32() {
- + int[] shape = new int[] {FLOAT_ARRAY1.length};
- + TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- + src.loadArray(FLOAT_ARRAY1);
- + TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
- + float[] data = dst.getFloatArray();
- + assertThat(data).isEqualTo(FLOAT_ARRAY1);
- + }
- +
- + @Test
- + public void testGetBuffer() throws IOException {
- + int[] shape = new int[] {1, 2, 3};
- + TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- + assertThat(tensorBufferUint8.getBuffer()).isNotNull();
- + }
- +
- + @Test
- + public void testLoadAndGetIntArrayWithFixedSizeForScalarArray() throws IOException {
- + ArrayTestRunner.Builder.newInstance()
- + .addSrcArray(INT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE)
- + .setTensorBufferShape(SCALAR_ARRAY_SHAPE)
- + .setExpectedResults(
- + /*bufferType = */ DataType.FLOAT32,
- + /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY_ROUNDED,
- + /*expectedIntArr=*/INT_SCALAR_ARRAY)
- + .setExpectedResults(
- + /*bufferType = */ DataType.UINT8,
- + /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY_CAPPED,
- + /*expectedIntArr=*/INT_SCALAR_ARRAY_CAPPED)
- + .build()
- + .run();
- + }
- +
- + @Test
- + public void testLoadAndGetFloatArrayWithFixedSizeForScalarArray() throws IOException {
- + ArrayTestRunner.Builder.newInstance()
- + .addSrcArray(FLOAT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE)
- + .setTensorBufferShape(SCALAR_ARRAY_SHAPE)
- + .setExpectedResults(
- + /*bufferType = */ DataType.FLOAT32,
- + /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY,
- + /*expectedIntArr=*/INT_SCALAR_ARRAY)
- + .setExpectedResults(
- + /*bufferType = */ DataType.UINT8,
- + /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY_CAPPED,
- + /*expectedIntArr=*/INT_SCALAR_ARRAY_CAPPED)
- + .build()
- + .run();
- + }
- +
- + @Test
- + public void testLoadAndGetIntArrayWithFixedSize() {
- + ArrayTestRunner.Builder.newInstance()
- + .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
- + .setTensorBufferShape(ARRAY1_SHAPE)
- + .setExpectedResults(
- + /*bufferType = */ DataType.FLOAT32,
- + /*expectedFloatArr=*/FLOAT_ARRAY1_ROUNDED,
- + /*expectedIntArr=*/INT_ARRAY1)
- + .setExpectedResults(
- + /*bufferType = */ DataType.UINT8,
- + /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED,
- + /*expectedIntArr=*/INT_ARRAY1_CAPPED)
- + .build()
- + .run();
- + }
- +
- + @Test
- + public void testLoadAndGetFloatArrayWithFixedSize() {
- + ArrayTestRunner.Builder.newInstance()
- + .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
- + .setTensorBufferShape(ARRAY1_SHAPE)
- + .setExpectedResults(
- + /*bufferType = */ DataType.FLOAT32,
- + /*expectedFloatArr=*/FLOAT_ARRAY1,
- + /*expectedIntArr=*/INT_ARRAY1)
- + .setExpectedResults(
- + /*bufferType = */ DataType.UINT8,
- + /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED,
- + /*expectedIntArr=*/INT_ARRAY1_CAPPED)
- + .build()
- + .run();
- + }
- +
- + @Test
- + public void testRepeatedLoadAndGetIntArrayWithSameFixedSize() {
- + ArrayTestRunner.Builder.newInstance()
- + .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE)
- + .addSrcArray(INT_ARRAY3, ARRAY3_SHAPE)
- + .setTensorBufferShape(ARRAY2_SHAPE)
- + .setExpectedResults(
- + /*bufferType = */ DataType.FLOAT32,
- + /*expectedFloatArr=*/FLOAT_ARRAY3_ROUNDED,
- + /*expectedIntArr=*/INT_ARRAY3)
- + .setExpectedResults(
- + /*bufferType = */ DataType.UINT8,
- + /*expectedFloatArr=*/FLOAT_ARRAY3_ROUNDED,
- + /*expectedIntArr=*/INT_ARRAY3)
- + .build()
- + .run();
- + }
- +
- + @Test
- + public void testRepeatedLoadAndGetFloatArrayWithSameFixedSize() {
- + ArrayTestRunner.Builder.newInstance()
- + .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
- + .addSrcArray(FLOAT_ARRAY3, ARRAY3_SHAPE)
- + .setTensorBufferShape(ARRAY2_SHAPE)
- + .setExpectedResults(
- + /*bufferType = */ DataType.FLOAT32,
- + /*expectedFloatArr=*/FLOAT_ARRAY3,
- + /*expectedIntArr=*/INT_ARRAY3)
- + .setExpectedResults(
- + /*bufferType = */ DataType.UINT8,
- + /*expectedFloatArr=*/FLOAT_ARRAY3_ROUNDED,
- + /*expectedIntArr=*/INT_ARRAY3)
- + .build()
- + .run();
- + }
- +
- + @Test
- + public void testRepeatedLoadIntArrayWithDifferentFixedSize() {
- + int[] srcArr1 = INT_ARRAY1;
- + int[] srcArr2 = INT_ARRAY2;
- + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- + TensorBuffer tensorBuffer =
- + TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType);
- + tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length});
- + // Load srcArr2 which had different size as srcArr1.
- + Assert.assertThrows(IllegalArgumentException.class,
- + () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length}));
- + }
- + }
- +
- + @Test
- + public void testRepeatedLoadFloatArrayWithDifferentFixedSize() {
- + float[] srcArr1 = FLOAT_ARRAY1;
- + float[] srcArr2 = FLOAT_ARRAY2;
- + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- + TensorBuffer tensorBuffer =
- + TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType);
- + tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length});
- + // Load srcArr2 which had different size as srcArr1.
- + Assert.assertThrows(IllegalArgumentException.class,
- + () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length}));
- + }
- + }
- +
- + @Test
- + public void testLoadAndGetIntArrayWithDynamicSize() {
- + ArrayTestRunner.Builder.newInstance()
- + .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
- + .setExpectedResults(
- + /*bufferType = */ DataType.FLOAT32,
- + /*expectedFloatArr=*/FLOAT_ARRAY1_ROUNDED,
- + /*expectedIntArr=*/INT_ARRAY1)
- + .setExpectedResults(
- + /*bufferType = */ DataType.UINT8,
- + /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED,
- + /*expectedIntArr=*/INT_ARRAY1_CAPPED)
- + .build()
- + .run();
- + }
- +
- + @Test
- + public void testLoadAndGetFloatArrayWithDynamicSize() {
- + ArrayTestRunner.Builder.newInstance()
- + .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
- + .setExpectedResults(
- + /*bufferType = */ DataType.FLOAT32,
- + /*expectedFloatArr=*/FLOAT_ARRAY1,
- + /*expectedIntArr=*/INT_ARRAY1)
- + .setExpectedResults(
- + /*bufferType = */ DataType.UINT8,
- + /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED,
- + /*expectedIntArr=*/INT_ARRAY1_CAPPED)
- + .build()
- + .run();
- + }
- +
- + @Test
- + public void testRepeatedLoadAndGetIntArrayWithDifferentDynamicSize() {
- + ArrayTestRunner.Builder.newInstance()
- + .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
- + .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE)
- + .setExpectedResults(
- + /*bufferType = */ DataType.FLOAT32,
- + /*expectedFloatArr=*/FLOAT_ARRAY2_ROUNDED,
- + /*expectedIntArr=*/INT_ARRAY2)
- + .setExpectedResults(
- + /*bufferType = */ DataType.UINT8,
- + /*expectedFloatArr=*/FLOAT_ARRAY2_ROUNDED,
- + /*expectedIntArr=*/INT_ARRAY2)
- + .build()
- + .run();
- + }
- +
- + @Test
- + public void testRepeatedLoadAndGetFloatArrayWithDifferentDynamicSize() {
- + ArrayTestRunner.Builder.newInstance()
- + .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
- + .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
- + .setExpectedResults(
- + /*bufferType = */ DataType.FLOAT32,
- + /*expectedFloatArr=*/FLOAT_ARRAY2,
- + /*expectedIntArr=*/INT_ARRAY2)
- + .setExpectedResults(
- + /*bufferType = */ DataType.UINT8,
- + /*expectedFloatArr=*/FLOAT_ARRAY2_ROUNDED,
- + /*expectedIntArr=*/INT_ARRAY2)
- + .build()
- + .run();
- + }
- +
- + @Test
- + public void testGetForEmptyArrayWithFixedSizeBuffer() {
- + ArrayTestRunner.Builder.newInstance()
- + .setTensorBufferShape(EMPTY_ARRAY_SHAPE)
- + .setExpectedResults(
- + /*bufferType = */ DataType.FLOAT32,
- + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
- + /*expectedIntArr=*/EMPTY_INT_ARRAY)
- + .setExpectedResults(
- + /*bufferType = */ DataType.UINT8,
- + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
- + /*expectedIntArr=*/EMPTY_INT_ARRAY)
- + .build()
- + .run();
- + }
- +
- + @Test
- + public void testGetForEmptyArrayWithDynamicBuffer() {
- + ArrayTestRunner.Builder.newInstance()
- + .setExpectedResults(
- + /*bufferType = */ DataType.FLOAT32,
- + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
- + /*expectedIntArr=*/EMPTY_INT_ARRAY)
- + .setExpectedResults(
- + /*bufferType = */ DataType.UINT8,
- + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
- + /*expectedIntArr=*/EMPTY_INT_ARRAY)
- + .build()
- + .run();
- + }
- +
- + @Test
- + public void testRepeatedLoadAndGetForEmptyArray() {
- + ArrayTestRunner.Builder.newInstance()
- + .addSrcArray(EMPTY_INT_ARRAY, EMPTY_ARRAY_SHAPE)
- + .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
- + .addSrcArray(EMPTY_FLOAT_ARRAY, EMPTY_ARRAY_SHAPE)
- + .setExpectedResults(
- + /*bufferType = */ DataType.FLOAT32,
- + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
- + /*expectedIntArr=*/EMPTY_INT_ARRAY)
- + .setExpectedResults(
- + /*bufferType = */ DataType.UINT8,
- + /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
- + /*expectedIntArr=*/EMPTY_INT_ARRAY)
- + .build()
- + .run();
- + }
- +
- + @Test
- + public void testLoadNullIntArrays() {
- + int[] nullArray = null;
- + int[] shape = new int[] {};
- + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- + Assert.assertThrows(
- + NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape));
- + }
- + }
- +
- + @Test
- + public void testLoadNullFloatArrays() {
- + float[] nullArray = null;
- + int[] shape = new int[] {};
- + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- + Assert.assertThrows(
- + NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape));
- + }
- + }
- +
- + @Test
- + public void testLoadFloatArraysWithNullShape() {
- + float[] arr = new float[] {1.0f};
- + int[] nullShape = null;
- + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- + Assert.assertThrows(
- + NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape));
- + }
- + }
- +
- + @Test
- + public void testLoadIntArraysWithNullShape() {
- + int[] arr = new int[] {1};
- + int[] nullShape = null;
- + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- + Assert.assertThrows(
- + NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape));
- + }
- + }
- +
- + @Test
- + public void testLoadIntArraysWithoutShapeAndArrayDoesNotMatchShape() {
- + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- + TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType);
- + Assert.assertThrows(
- + IllegalArgumentException.class, () -> fixedTensorBuffer.loadArray(INT_ARRAY2));
- + }
- + }
- +
- + @Test
- + public void testLoadFloatArraysWithoutShapeAndArrayDoesNotMatchShape() {
- + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- + TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType);
- + Assert.assertThrows(IllegalArgumentException.class,
- + () -> fixedTensorBuffer.loadArray(FLOAT_ARRAY2));
- + }
- + }
- +
- + @Test
- + public void testLoadByteBufferForNullBuffer() {
- + ByteBuffer byteBuffer = null;
- + int[] shape = new int[] {};
- + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- + Assert.assertThrows(
- + NullPointerException.class, () -> tensorBuffer.loadBuffer(byteBuffer, shape));
- + }
- + }
- +
- + @Test
- + public void testLoadByteBufferForEmptyBuffer() {
- + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- + tensorBuffer.loadBuffer(EMPTY_BYTE_BUFFER, EMPTY_ARRAY_SHAPE);
- + assertThat(tensorBuffer.getFlatSize()).isEqualTo(0);
- + }
- + }
- +
- + @Test
- + public void testLoadByteBufferWithDifferentFixedSize() {
- + // Create a fixed-size TensorBuffer with size 2, and load a ByteBuffer with size 5.
- + int[] tensorBufferShape = new int[] {2};
- + TensorBuffer tensorBuffer =
- + TensorBuffer.createFixedSize(tensorBufferShape, DataType.FLOAT32);
- + Assert.assertThrows(IllegalArgumentException.class,
- + () -> tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE));
- + }
- +
- + @Test
- + public void testLoadByteBufferWithMisMatchDataType() {
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- + int[] wrongShape = new int[] {1};
- + // Size of INT_BYTE_BUFFER is 8 bytes. It does not match the specified shape.
- + Assert.assertThrows(IllegalArgumentException.class,
- + () -> tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, wrongShape));
- + }
- +
- + @Test
- + public void testLoadByteBufferForTensorBufferFloat() {
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- + tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE);
- + assertThat(tensorBuffer.getFloatArray()).isEqualTo(FLOAT_ARRAY1);
- + assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY1_SHAPE);
- + }
- +
- + @Test
- + public void testLoadByteBufferForTensorBufferUint8() {
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- + tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, ARRAY2_SHAPE);
- + assertThat(tensorBuffer.getIntArray()).isEqualTo(INT_ARRAY2);
- + assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY2_SHAPE);
- + }
- +
- + @Test
- + public void testGetFloatValueWithInvalidIndex() {
- + float[] arrayWithSixElements = FLOAT_ARRAY1;
- + int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE;
- + int[] invalidIndexes = {-1, 7};
- + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- + tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements);
- + for (int invalidIndex : invalidIndexes) {
- + Assert.assertThrows(IndexOutOfBoundsException.class,
- + () -> tensorBuffer.getFloatValue(invalidIndex));
- + }
- + }
- + }
- +
- + @Test
- + public void testGetFloatValueFromScalarWithInvalidIndex() {
- + int[] shape = new int[] {};
- + float[] arr = new float[] {10.0f};
- + int[] invalidIndexes =
- + new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize.
- + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- + tensorBuffer.loadArray(arr, shape);
- + for (int invalidIndex : invalidIndexes) {
- + Assert.assertThrows(IndexOutOfBoundsException.class,
- + () -> tensorBuffer.getFloatValue(invalidIndex));
- + }
- + }
- + }
- +
- + @Test
- + public void testGetIntValueWithInvalidIndex() {
- + float[] arrayWithSixElements = FLOAT_ARRAY1;
- + int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE;
- + int[] invalidIndexes = {-1, 7};
- + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- + tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements);
- + for (int invalidIndex : invalidIndexes) {
- + Assert.assertThrows(IndexOutOfBoundsException.class,
- + () -> tensorBuffer.getIntValue(invalidIndex));
- + }
- + }
- + }
- +
- + @Test
- + public void testGetIntValueFromScalarWithInvalidIndex() {
- + int[] shape = new int[] {};
- + float[] arr = new float[] {10.0f};
- + int[] invalidIndexes =
- + new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize.
- + for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- + tensorBuffer.loadArray(arr, shape);
- + for (int invalidIndex : invalidIndexes) {
- + Assert.assertThrows(IndexOutOfBoundsException.class,
- + () -> tensorBuffer.getIntValue(invalidIndex));
- + }
- + }
- + }
- +
- + @Test
- + public void testLoadByteBufferSliceForTensorBufferFloat() {
- + TensorBuffer original = TensorBuffer.createDynamic(DataType.FLOAT32);
- + original.loadArray(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, new int[] {6});
- + ByteBuffer buffer = original.getBuffer();
- + // Slice original buffer to 3 sub-buffer, each of which has 2 element
- + int numBuffers = 3;
- + int numElements = 2;
- + int subArrayLength = numElements * original.getTypeSize();
- + TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType());
- + for (int i = 0; i < numBuffers; i++) {
- + buffer.position(i * subArrayLength);
- + ByteBuffer subBuffer = buffer.slice();
- + // ByteBuffer.slice doesn't keep order.
- + subBuffer.order(buffer.order()).limit(subArrayLength);
- + tensorSlice.loadBuffer(subBuffer, new int[] {numElements});
- + float[] arraySlice = tensorSlice.getFloatArray();
- + assertThat(arraySlice.length).isEqualTo(numElements);
- + assertThat(arraySlice[0]).isEqualTo(i * numElements + 1);
- + assertThat(arraySlice[1]).isEqualTo(i * numElements + 2);
- + }
- + }
- +
- + @Test
- + public void testLoadByteBufferSliceForTensorBufferUInt8() {
- + TensorBuffer original = TensorBuffer.createDynamic(DataType.UINT8);
- + original.loadArray(new int[] {1, 2, 3, 4, 5, 6}, new int[] {6});
- + ByteBuffer buffer = original.getBuffer();
- + // Slice original buffer to 3 sub-buffer, each of which has 2 element
- + int numBuffers = 3;
- + int numElements = 2;
- + int subArrayLength = numElements * original.getTypeSize();
- + TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType());
- + for (int i = 0; i < numBuffers; i++) {
- + buffer.position(i * subArrayLength);
- + ByteBuffer subBuffer = buffer.slice();
- + // ByteBuffer.slice doesn't keep order.
- + subBuffer.order(buffer.order()).limit(subArrayLength);
- + tensorSlice.loadBuffer(subBuffer, new int[] {numElements});
- + int[] arraySlice = tensorSlice.getIntArray();
- + assertThat(arraySlice.length).isEqualTo(numElements);
- + assertThat(arraySlice[0]).isEqualTo(i * numElements + 1);
- + assertThat(arraySlice[1]).isEqualTo(i * numElements + 2);
- + }
- + }
- +
- + @Test
- + public void getShapeFailsAfterByteBufferChanged() {
- + TensorBuffer tensorBuffer =
- + TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32);
- + ByteBuffer byteBuffer = tensorBuffer.getBuffer();
- + byteBuffer.limit(5);
- +
- + IllegalStateException exception =
- + assertThrows(IllegalStateException.class, tensorBuffer::getShape);
- + assertThat(exception).hasMessageThat().contains(
- + "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The"
- + " ByteBuffer may have been changed.");
- - }
- -
- - @Test
- - public void getFlatSizeFailsAfterByteBufferChanged() {
- - TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32);
- - ByteBuffer byteBuffer = tensorBuffer.getBuffer();
- - byteBuffer.limit(5);
- -
- - IllegalStateException exception =
- - assertThrows(IllegalStateException.class, tensorBuffer::getFlatSize);
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The"
- + }
- +
- + @Test
- + public void getFlatSizeFailsAfterByteBufferChanged() {
- + TensorBuffer tensorBuffer =
- + TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32);
- + ByteBuffer byteBuffer = tensorBuffer.getBuffer();
- + byteBuffer.limit(5);
- +
- + IllegalStateException exception =
- + assertThrows(IllegalStateException.class, tensorBuffer::getFlatSize);
- + assertThat(exception).hasMessageThat().contains(
- + "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The"
- + " ByteBuffer may have been changed.");
- - }
- -
- - @Test
- - public void loadReadOnlyBuffersCopiesOnWrite() {
- - TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- - ByteBuffer originalByteBuffer = ByteBuffer.allocateDirect(1);
- - originalByteBuffer.put(new byte[]{99});
- - originalByteBuffer.rewind();
- - ByteBuffer readOnlyByteBuffer = originalByteBuffer.asReadOnlyBuffer();
- -
- - tensorBuffer.loadBuffer(readOnlyByteBuffer, new int[]{1});
- - assertThat(tensorBuffer.getBuffer()).isSameInstanceAs(readOnlyByteBuffer);
- -
- - tensorBuffer.loadArray(new int[]{42});
- - assertThat(tensorBuffer.getBuffer()).isNotSameInstanceAs(readOnlyByteBuffer);
- - assertThat(tensorBuffer.getBuffer().get(0)).isEqualTo(42); // updated
- - assertThat(originalByteBuffer.get(0)).isEqualTo(99); // original one not changed
- - }
- + }
- +
- + @Test
- + public void loadReadOnlyBuffersCopiesOnWrite() {
- + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- + ByteBuffer originalByteBuffer = ByteBuffer.allocateDirect(1);
- + originalByteBuffer.put(new byte[] {99});
- + originalByteBuffer.rewind();
- + ByteBuffer readOnlyByteBuffer = originalByteBuffer.asReadOnlyBuffer();
- +
- + tensorBuffer.loadBuffer(readOnlyByteBuffer, new int[] {1});
- + assertThat(tensorBuffer.getBuffer()).isSameInstanceAs(readOnlyByteBuffer);
- +
- + tensorBuffer.loadArray(new int[] {42});
- + assertThat(tensorBuffer.getBuffer()).isNotSameInstanceAs(readOnlyByteBuffer);
- + assertThat(tensorBuffer.getBuffer().get(0)).isEqualTo(42); // updated
- + assertThat(originalByteBuffer.get(0)).isEqualTo(99); // original one not changed
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java
- index e843133275d61..1921f4e467d01 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java
- @@ -26,51 +26,51 @@ import org.tensorflow.lite.DataType;
- /** Tests of {@link org.tensorflow.lite.support.tensorbuffer.TensorBufferUint8}. */
- @RunWith(RobolectricTestRunner.class)
- public final class TensorBufferUint8Test {
- - @Test
- - public void testCreateDynamic() {
- - TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8();
- - assertThat(tensorBufferUint8).isNotNull();
- - }
- + @Test
- + public void testCreateDynamic() {
- + TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8();
- + assertThat(tensorBufferUint8).isNotNull();
- + }
-
- - @Test
- - public void testCreateFixedSize() {
- - int[] shape = new int[] {1, 2, 3};
- - TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
- - assertThat(tensorBufferUint8).isNotNull();
- - assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6);
- - }
- + @Test
- + public void testCreateFixedSize() {
- + int[] shape = new int[] {1, 2, 3};
- + TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
- + assertThat(tensorBufferUint8).isNotNull();
- + assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6);
- + }
-
- - @Test
- - public void testCreateFixedSizeWithScalarShape() {
- - int[] shape = new int[] {};
- - TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
- - assertThat(tensorBufferUint8).isNotNull();
- - assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(1);
- - }
- + @Test
- + public void testCreateFixedSizeWithScalarShape() {
- + int[] shape = new int[] {};
- + TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
- + assertThat(tensorBufferUint8).isNotNull();
- + assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(1);
- + }
-
- - @Test
- - public void testCreateWithNullShape() {
- - int[] shape = null;
- - Assert.assertThrows(NullPointerException.class, () -> new TensorBufferUint8(shape));
- - }
- + @Test
- + public void testCreateWithNullShape() {
- + int[] shape = null;
- + Assert.assertThrows(NullPointerException.class, () -> new TensorBufferUint8(shape));
- + }
-
- - @Test
- - public void testCreateWithInvalidShape() {
- - int[] shape = new int[] {1, -1, 2};
- - Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferUint8(shape));
- - }
- + @Test
- + public void testCreateWithInvalidShape() {
- + int[] shape = new int[] {1, -1, 2};
- + Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferUint8(shape));
- + }
-
- - @Test
- - public void testCreateUsingShapeWithZero() {
- - int[] shape = new int[] {1, 0, 2};
- - TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
- - assertThat(tensorBufferUint8).isNotNull();
- - assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(0);
- - }
- + @Test
- + public void testCreateUsingShapeWithZero() {
- + int[] shape = new int[] {1, 0, 2};
- + TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
- + assertThat(tensorBufferUint8).isNotNull();
- + assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(0);
- + }
-
- - @Test
- - public void testGetDataType() {
- - TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8();
- - assertThat(tensorBufferUint8.getDataType()).isEqualTo(DataType.UINT8);
- - }
- + @Test
- + public void testGetDataType() {
- + TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8();
- + assertThat(tensorBufferUint8.getDataType()).isEqualTo(DataType.UINT8);
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc
- index d62da546a484b..c3c21fa43ab49 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc
- @@ -134,7 +134,8 @@ jobject ConvertToClassificationResults(JNIEnv* env,
- }
-
- // Creates an AudioClassifierOptions proto based on the Java class.
- -AudioClassifierOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options,
- +AudioClassifierOptions ConvertToProtoOptions(JNIEnv* env,
- + jobject java_options,
- jlong base_options_handle) {
- AudioClassifierOptions proto_options;
-
- @@ -214,7 +215,9 @@ jlong CreateAudioClassifierFromOptions(JNIEnv* env,
-
- extern "C" JNIEXPORT void JNICALL
- Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_deinitJni(
- - JNIEnv* env, jobject thiz, jlong native_handle) {
- + JNIEnv* env,
- + jobject thiz,
- + jlong native_handle) {
- delete reinterpret_cast<AudioClassifier*>(native_handle);
- }
-
- @@ -223,9 +226,13 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_deinitJni(
- // values will be ignored.
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithModelFdAndOptions(
- - JNIEnv* env, jclass thiz, jint file_descriptor,
- - jlong file_descriptor_length, jlong file_descriptor_offset,
- - jobject java_options, jlong base_options_handle) {
- + JNIEnv* env,
- + jclass thiz,
- + jint file_descriptor,
- + jlong file_descriptor_length,
- + jlong file_descriptor_offset,
- + jobject java_options,
- + jlong base_options_handle) {
- AudioClassifierOptions proto_options =
- ConvertToProtoOptions(env, java_options, base_options_handle);
- auto file_descriptor_meta = proto_options.mutable_base_options()
- @@ -243,7 +250,10 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithModelF
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithByteBuffer(
- - JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options,
- + JNIEnv* env,
- + jclass thiz,
- + jobject model_buffer,
- + jobject java_options,
- jlong base_options_handle) {
- AudioClassifierOptions proto_options =
- ConvertToProtoOptions(env, java_options, base_options_handle);
- @@ -262,7 +272,9 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithByteBu
- // caching it in JAVA layer.
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredSampleRateNative(
- - JNIEnv* env, jclass thiz, jlong native_handle) {
- + JNIEnv* env,
- + jclass thiz,
- + jlong native_handle) {
- auto* classifier = reinterpret_cast<AudioClassifier*>(native_handle);
- StatusOr<AudioBuffer::AudioFormat> format_or =
- classifier->GetRequiredAudioFormat();
- @@ -279,7 +291,9 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredSample
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredChannelsNative(
- - JNIEnv* env, jclass thiz, jlong native_handle) {
- + JNIEnv* env,
- + jclass thiz,
- + jlong native_handle) {
- auto* classifier = reinterpret_cast<AudioClassifier*>(native_handle);
- StatusOr<AudioBuffer::AudioFormat> format_or =
- classifier->GetRequiredAudioFormat();
- @@ -296,15 +310,21 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredChanne
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredInputBufferSizeNative(
- - JNIEnv* env, jclass thiz, jlong native_handle) {
- + JNIEnv* env,
- + jclass thiz,
- + jlong native_handle) {
- auto* classifier = reinterpret_cast<AudioClassifier*>(native_handle);
- return classifier->GetRequiredInputBufferSize();
- }
-
- extern "C" JNIEXPORT jobject JNICALL
- Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_classifyNative(
- - JNIEnv* env, jclass thiz, jlong native_handle, jbyteArray java_array,
- - jint channels, jint sample_rate) {
- + JNIEnv* env,
- + jclass thiz,
- + jlong native_handle,
- + jbyteArray java_array,
- + jint channels,
- + jint sample_rate) {
- // Get the primitive native array. Depending on the JAVA runtime, the returned
- // array might be a copy of the JAVA array (or not).
- jbyte* native_array = env->GetByteArrayElements(java_array, nullptr);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc
- index 2fd1d7ca9a593..75f93d6f2e458 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc
- @@ -30,7 +30,10 @@ using ::tflite::task::core::BaseOptions;
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_core_TaskJniUtils_createProtoBaseOptions(
- - JNIEnv* env, jclass thiz, jint delegate, jint num_threads) {
- + JNIEnv* env,
- + jclass thiz,
- + jint delegate,
- + jint num_threads) {
- StatusOr<Delegate> delegate_proto_or = ConvertToProtoDelegate(delegate);
- if (!delegate_proto_or.ok()) {
- ThrowException(env, kIllegalStateException,
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc
- index 6657ef4ca2d95..2daacdf893903 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc
- @@ -32,7 +32,9 @@ using ::tflite::task::text::BertNLClassifierOptions;
- using ::tflite::task::text::nlclassifier::RunClassifier;
-
- BertNLClassifierOptions ConvertJavaBertNLClassifierOptions(
- - JNIEnv* env, jobject java_options, jlong base_options_handle) {
- + JNIEnv* env,
- + jobject java_options,
- + jlong base_options_handle) {
- BertNLClassifierOptions proto_options;
-
- if (base_options_handle != kInvalidPointer) {
- @@ -47,13 +49,18 @@ BertNLClassifierOptions ConvertJavaBertNLClassifierOptions(
-
- extern "C" JNIEXPORT void JNICALL
- Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_deinitJni(
- - JNIEnv* env, jobject thiz, jlong native_handle) {
- + JNIEnv* env,
- + jobject thiz,
- + jlong native_handle) {
- delete reinterpret_cast<BertNLClassifier*>(native_handle);
- }
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByteBuffer(
- - JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options,
- + JNIEnv* env,
- + jclass thiz,
- + jobject model_buffer,
- + jobject java_options,
- jlong base_options_handle) {
- BertNLClassifierOptions proto_options = ConvertJavaBertNLClassifierOptions(
- env, java_options, base_options_handle);
- @@ -76,7 +83,10 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByte
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithFileDescriptor(
- - JNIEnv* env, jclass thiz, jint fd, jobject java_options,
- + JNIEnv* env,
- + jclass thiz,
- + jint fd,
- + jobject java_options,
- jlong base_options_handle) {
- BertNLClassifierOptions proto_options = ConvertJavaBertNLClassifierOptions(
- env, java_options, base_options_handle);
- @@ -100,6 +110,9 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithFile
-
- extern "C" JNIEXPORT jobject JNICALL
- Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_classifyNative(
- - JNIEnv* env, jclass clazz, jlong native_handle, jstring text) {
- + JNIEnv* env,
- + jclass clazz,
- + jlong native_handle,
- + jstring text) {
- return RunClassifier(env, native_handle, text);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc
- index f6d34a5f74e2b..4c71a80ea1528 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc
- @@ -94,14 +94,19 @@ NLClassifierOptions ConvertToProtoOptions(JNIEnv* env,
-
- extern "C" JNIEXPORT void JNICALL
- Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_deinitJni(
- - JNIEnv* env, jobject thiz, jlong native_handle) {
- + JNIEnv* env,
- + jobject thiz,
- + jlong native_handle) {
- delete reinterpret_cast<NLClassifier*>(native_handle);
- }
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithByteBuffer(
- - JNIEnv* env, jclass thiz, jobject nl_classifier_options,
- - jobject model_buffer, jlong base_options_handle) {
- + JNIEnv* env,
- + jclass thiz,
- + jobject nl_classifier_options,
- + jobject model_buffer,
- + jlong base_options_handle) {
- auto model = GetMappedFileBuffer(env, model_buffer);
- tflite::support::StatusOr<std::unique_ptr<NLClassifier>> classifier_or;
-
- @@ -125,7 +130,10 @@ Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithByteBuff
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithFileDescriptor(
- - JNIEnv* env, jclass thiz, jobject nl_classifier_options, jint fd,
- + JNIEnv* env,
- + jclass thiz,
- + jobject nl_classifier_options,
- + jint fd,
- jlong base_options_handle) {
- tflite::support::StatusOr<std::unique_ptr<NLClassifier>> classifier_or;
-
- @@ -151,6 +159,9 @@ Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithFileDesc
-
- extern "C" JNIEXPORT jobject JNICALL
- Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_classifyNative(
- - JNIEnv* env, jclass thiz, jlong native_handle, jstring text) {
- + JNIEnv* env,
- + jclass thiz,
- + jlong native_handle,
- + jstring text) {
- return RunClassifier(env, native_handle, text);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc
- index 1ff0d9fc46161..b77746a2eee68 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc
- @@ -52,14 +52,19 @@ BertQuestionAnswererOptions ConvertToProtoOptions(jlong base_options_handle) {
-
- extern "C" JNIEXPORT void JNICALL
- Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_deinitJni(
- - JNIEnv* env, jobject thiz, jlong native_handle) {
- + JNIEnv* env,
- + jobject thiz,
- + jlong native_handle) {
- delete reinterpret_cast<QuestionAnswerer*>(native_handle);
- }
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescriptor(
- - JNIEnv* env, jclass thiz, jint file_descriptor,
- - jlong file_descriptor_length, jlong file_descriptor_offset,
- + JNIEnv* env,
- + jclass thiz,
- + jint file_descriptor,
- + jlong file_descriptor_length,
- + jlong file_descriptor_offset,
- jlong base_options_handle) {
- BertQuestionAnswererOptions proto_options =
- ConvertToProtoOptions(base_options_handle);
- @@ -89,7 +94,9 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescri
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers(
- - JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
- + JNIEnv* env,
- + jclass thiz,
- + jobjectArray model_buffers) {
- absl::string_view model =
- GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
- absl::string_view vocab =
- @@ -111,7 +118,9 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBu
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithAlbertByteBuffers(
- - JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
- + JNIEnv* env,
- + jclass thiz,
- + jobjectArray model_buffers) {
- absl::string_view model =
- GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
- absl::string_view sp_model =
- @@ -133,7 +142,10 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithAlbertByte
-
- extern "C" JNIEXPORT jobject JNICALL
- Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(
- - JNIEnv* env, jclass thiz, jlong native_handle, jstring context,
- + JNIEnv* env,
- + jclass thiz,
- + jlong native_handle,
- + jstring context,
- jstring question) {
- auto* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle);
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/text_searcher_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/text_searcher_jni.cc
- index 8573b0f444626..c207755d3393f 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/text_searcher_jni.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/text_searcher_jni.cc
- @@ -48,7 +48,8 @@ using ::tflite::task::text::TextSearcherOptions;
-
- // Creates an TextSearcherOptions proto based on the Java class.
- TextSearcherOptions ConvertToProtoOptions(jlong base_options_handle,
- - bool l2_normalize, bool quantize,
- + bool l2_normalize,
- + bool quantize,
- int index_descriptor,
- int max_results) {
- TextSearcherOptions proto_options;
- @@ -120,7 +121,9 @@ jobject ConvertToSearchResults(JNIEnv* env, const SearchResult& results) {
-
- extern "C" JNIEXPORT void JNICALL
- Java_org_tensorflow_lite_task_text_searcher_TextSearcher_deinitJni(
- - JNIEnv* env, jobject thiz, jlong native_handle) {
- + JNIEnv* env,
- + jobject thiz,
- + jlong native_handle) {
- delete reinterpret_cast<TextSearcher*>(native_handle);
- }
-
- @@ -129,10 +132,16 @@ Java_org_tensorflow_lite_task_text_searcher_TextSearcher_deinitJni(
- // values will be ignored.
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_text_searcher_TextSearcher_initJniWithModelFdAndOptions(
- - JNIEnv* env, jclass thiz, jint model_descriptor,
- - jlong model_descriptor_length, jlong model_descriptor_offset,
- - jlong base_options_handle, bool l2_normalize, bool quantize,
- - jint index_descriptor, int max_results) {
- + JNIEnv* env,
- + jclass thiz,
- + jint model_descriptor,
- + jlong model_descriptor_length,
- + jlong model_descriptor_offset,
- + jlong base_options_handle,
- + bool l2_normalize,
- + bool quantize,
- + jint index_descriptor,
- + int max_results) {
- TextSearcherOptions proto_options =
- ConvertToProtoOptions(base_options_handle, l2_normalize, quantize,
- index_descriptor, max_results);
- @@ -152,8 +161,14 @@ Java_org_tensorflow_lite_task_text_searcher_TextSearcher_initJniWithModelFdAndOp
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_text_searcher_TextSearcher_initJniWithByteBuffer(
- - JNIEnv* env, jclass thiz, jobject model_buffer, jlong base_options_handle,
- - bool l2_normalize, bool quantize, jlong index_descriptor, int max_results) {
- + JNIEnv* env,
- + jclass thiz,
- + jobject model_buffer,
- + jlong base_options_handle,
- + bool l2_normalize,
- + bool quantize,
- + jlong index_descriptor,
- + int max_results) {
- TextSearcherOptions proto_options =
- ConvertToProtoOptions(base_options_handle, l2_normalize, quantize,
- index_descriptor, max_results);
- @@ -166,7 +181,10 @@ Java_org_tensorflow_lite_task_text_searcher_TextSearcher_initJniWithByteBuffer(
-
- extern "C" JNIEXPORT jobject JNICALL
- Java_org_tensorflow_lite_task_text_searcher_TextSearcher_searchNative(
- - JNIEnv* env, jclass thiz, jlong native_handle, jstring text) {
- + JNIEnv* env,
- + jclass thiz,
- + jlong native_handle,
- + jstring text) {
- auto* searcher = reinterpret_cast<TextSearcher*>(native_handle);
- auto results_or = searcher->Search(JStringToString(env, text));
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc
- index 18e2ee1a7d4ab..2a713cf8b63cf 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc
- @@ -54,7 +54,8 @@ using ::tflite::task::vision::ImageClassifier;
- using ::tflite::task::vision::ImageClassifierOptions;
-
- // Creates an ImageClassifierOptions proto based on the Java class.
- -ImageClassifierOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options,
- +ImageClassifierOptions ConvertToProtoOptions(JNIEnv* env,
- + jobject java_options,
- jlong base_options_handle) {
- ImageClassifierOptions proto_options;
-
- @@ -175,7 +176,9 @@ jlong CreateImageClassifierFromOptions(JNIEnv* env,
-
- extern "C" JNIEXPORT void JNICALL
- Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_deinitJni(
- - JNIEnv* env, jobject thiz, jlong native_handle) {
- + JNIEnv* env,
- + jobject thiz,
- + jlong native_handle) {
- delete reinterpret_cast<ImageClassifier*>(native_handle);
- }
-
- @@ -184,9 +187,13 @@ Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_deinitJni(
- // values will be ignored.
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithModelFdAndOptions(
- - JNIEnv* env, jclass thiz, jint file_descriptor,
- - jlong file_descriptor_length, jlong file_descriptor_offset,
- - jobject java_options, jlong base_options_handle) {
- + JNIEnv* env,
- + jclass thiz,
- + jint file_descriptor,
- + jlong file_descriptor_length,
- + jlong file_descriptor_offset,
- + jobject java_options,
- + jlong base_options_handle) {
- ImageClassifierOptions proto_options =
- ConvertToProtoOptions(env, java_options, base_options_handle);
- auto file_descriptor_meta = proto_options.mutable_base_options()
- @@ -204,7 +211,10 @@ Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithModel
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithByteBuffer(
- - JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options,
- + JNIEnv* env,
- + jclass thiz,
- + jobject model_buffer,
- + jobject java_options,
- jlong base_options_handle) {
- ImageClassifierOptions proto_options =
- ConvertToProtoOptions(env, java_options, base_options_handle);
- @@ -220,7 +230,10 @@ Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithByteB
-
- extern "C" JNIEXPORT jobject JNICALL
- Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_classifyNative(
- - JNIEnv* env, jclass thiz, jlong native_handle, jlong frame_buffer_handle,
- + JNIEnv* env,
- + jclass thiz,
- + jlong native_handle,
- + jlong frame_buffer_handle,
- jintArray jroi) {
- auto* classifier = reinterpret_cast<ImageClassifier*>(native_handle);
- // frame_buffer will be deleted after inference is done in
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc
- index 84bff227f2543..2cda1b500aeb5 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc
- @@ -31,8 +31,13 @@ using ::tflite::task::vision::FrameBuffer;
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFromByteBuffer(
- - JNIEnv* env, jclass thiz, jobject jimage_byte_buffer, jint width,
- - jint height, jint jorientation, jint jcolor_space_type) {
- + JNIEnv* env,
- + jclass thiz,
- + jobject jimage_byte_buffer,
- + jint width,
- + jint height,
- + jint jorientation,
- + jint jcolor_space_type) {
- auto frame_buffer_or = CreateFrameBufferFromByteBuffer(
- env, jimage_byte_buffer, width, height, jorientation, jcolor_space_type);
- if (frame_buffer_or.ok()) {
- @@ -49,8 +54,14 @@ Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFro
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFromBytes(
- - JNIEnv* env, jclass thiz, jbyteArray jimage_bytes, jint width, jint height,
- - jint jorientation, jint jcolor_space_type, jlongArray jbyte_array_handle) {
- + JNIEnv* env,
- + jclass thiz,
- + jbyteArray jimage_bytes,
- + jint width,
- + jint height,
- + jint jorientation,
- + jint jcolor_space_type,
- + jlongArray jbyte_array_handle) {
- auto frame_buffer_or =
- CreateFrameBufferFromBytes(env, jimage_bytes, width, height, jorientation,
- jcolor_space_type, jbyte_array_handle);
- @@ -68,9 +79,17 @@ Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFro
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFromPlanes(
- - JNIEnv* env, jclass thiz, jobject jy_plane, jobject ju_plane,
- - jobject jv_plane, jint width, jint height, jint row_stride_y,
- - jint row_stride_uv, jint pixel_stride_uv, jint orientation) {
- + JNIEnv* env,
- + jclass thiz,
- + jobject jy_plane,
- + jobject ju_plane,
- + jobject jv_plane,
- + jint width,
- + jint height,
- + jint row_stride_y,
- + jint row_stride_uv,
- + jint pixel_stride_uv,
- + jint orientation) {
- auto frame_buffer_or = CreateFrameBufferFromYuvPlanes(
- env, jy_plane, ju_plane, jv_plane, width, height, row_stride_y,
- row_stride_uv, pixel_stride_uv, orientation);
- @@ -88,8 +107,11 @@ Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFro
-
- extern "C" JNIEXPORT void JNICALL
- Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_deleteFrameBuffer(
- - JNIEnv* env, jobject thiz, jlong frame_buffer_handle,
- - jlong byte_array_handle, jbyteArray jbyte_array) {
- + JNIEnv* env,
- + jobject thiz,
- + jlong frame_buffer_handle,
- + jlong byte_array_handle,
- + jbyteArray jbyte_array) {
- delete reinterpret_cast<FrameBuffer*>(frame_buffer_handle);
- jbyte* bytes_ptr = reinterpret_cast<jbyte*>(byte_array_handle);
- if (bytes_ptr != NULL) {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc
- index ddb0b72a25b65..f720795263791 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc
- @@ -54,7 +54,8 @@ using ::tflite::task::vision::ObjectDetector;
- using ::tflite::task::vision::ObjectDetectorOptions;
-
- // Creates an ObjectDetectorOptions proto based on the Java class.
- -ObjectDetectorOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options,
- +ObjectDetectorOptions ConvertToProtoOptions(JNIEnv* env,
- + jobject java_options,
- jlong base_options_handle) {
- ObjectDetectorOptions proto_options;
-
- @@ -183,7 +184,9 @@ jlong CreateObjectDetectorFromOptions(JNIEnv* env,
-
- extern "C" JNIEXPORT void JNICALL
- Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_deinitJni(
- - JNIEnv* env, jobject thiz, jlong native_handle) {
- + JNIEnv* env,
- + jobject thiz,
- + jlong native_handle) {
- delete reinterpret_cast<ObjectDetector*>(native_handle);
- }
-
- @@ -192,9 +195,13 @@ Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_deinitJni(
- // values will be ignored.
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithModelFdAndOptions(
- - JNIEnv* env, jclass thiz, jint file_descriptor,
- - jlong file_descriptor_length, jlong file_descriptor_offset,
- - jobject java_options, jlong base_options_handle) {
- + JNIEnv* env,
- + jclass thiz,
- + jint file_descriptor,
- + jlong file_descriptor_length,
- + jlong file_descriptor_offset,
- + jobject java_options,
- + jlong base_options_handle) {
- ObjectDetectorOptions proto_options =
- ConvertToProtoOptions(env, java_options, base_options_handle);
- auto file_descriptor_meta = proto_options.mutable_base_options()
- @@ -212,7 +219,10 @@ Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithModelFdA
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithByteBuffer(
- - JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options,
- + JNIEnv* env,
- + jclass thiz,
- + jobject model_buffer,
- + jobject java_options,
- jlong base_options_handle) {
- ObjectDetectorOptions proto_options =
- ConvertToProtoOptions(env, java_options, base_options_handle);
- @@ -224,7 +234,10 @@ Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithByteBuff
-
- extern "C" JNIEXPORT jobject JNICALL
- Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_detectNative(
- - JNIEnv* env, jclass thiz, jlong native_handle, jlong frame_buffer_handle) {
- + JNIEnv* env,
- + jclass thiz,
- + jlong native_handle,
- + jlong frame_buffer_handle) {
- auto* detector = reinterpret_cast<ObjectDetector*>(native_handle);
- // frame_buffer will be deleted after inference is done in
- // base_vision_api_jni.cc.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc
- index 1b08e56ed509b..e0c94e2ec72c6 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc
- @@ -135,8 +135,12 @@ StatusOr<FrameBuffer::Format> GetYUVImageFormat(const uint8* u_buffer,
- }
-
- StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromByteBuffer(
- - JNIEnv* env, jobject jimage_byte_buffer, jint width, jint height,
- - jint jorientation, jint jcolor_space_type) {
- + JNIEnv* env,
- + jobject jimage_byte_buffer,
- + jint width,
- + jint height,
- + jint jorientation,
- + jint jcolor_space_type) {
- absl::string_view image = GetMappedFileBuffer(env, jimage_byte_buffer);
- return CreateFromRawBuffer(
- reinterpret_cast<const uint8*>(image.data()),
- @@ -146,8 +150,13 @@ StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromByteBuffer(
- }
-
- StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromBytes(
- - JNIEnv* env, jbyteArray jimage_bytes, jint width, jint height,
- - jint jorientation, jint jcolor_space_type, jlongArray jbyte_array_handle) {
- + JNIEnv* env,
- + jbyteArray jimage_bytes,
- + jint width,
- + jint height,
- + jint jorientation,
- + jint jcolor_space_type,
- + jlongArray jbyte_array_handle) {
- jbyte* jimage_ptr = env->GetByteArrayElements(jimage_bytes, NULL);
- // Free jimage_ptr together with frame_buffer after inference is finished.
- jlong jimage_ptr_handle = reinterpret_cast<jlong>(jimage_ptr);
- @@ -168,9 +177,16 @@ StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromBytes(
- }
-
- StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromYuvPlanes(
- - JNIEnv* env, jobject jy_plane, jobject ju_plane, jobject jv_plane,
- - jint width, jint height, jint row_stride_y, jint row_stride_uv,
- - jint pixel_stride_uv, jint jorientation) {
- + JNIEnv* env,
- + jobject jy_plane,
- + jobject ju_plane,
- + jobject jv_plane,
- + jint width,
- + jint height,
- + jint row_stride_y,
- + jint row_stride_uv,
- + jint pixel_stride_uv,
- + jint jorientation) {
- const uint8* y_plane =
- reinterpret_cast<const uint8*>(GetMappedFileBuffer(env, jy_plane).data());
- const uint8* u_plane =
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h
- index dbe32f8a3f2a5..4d7ec17a1c042 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h
- @@ -34,23 +34,35 @@ FrameBuffer::Orientation ConvertToFrameBufferOrientation(JNIEnv* env,
-
- // Creates FrameBuffer from a direct ByteBuffer.
- ::tflite::support::StatusOr<std::unique_ptr<FrameBuffer>>
- -CreateFrameBufferFromByteBuffer(JNIEnv* env, jobject jimage_byte_buffer,
- - jint width, jint height, jint jorientation,
- +CreateFrameBufferFromByteBuffer(JNIEnv* env,
- + jobject jimage_byte_buffer,
- + jint width,
- + jint height,
- + jint jorientation,
- jint jcolor_space_type);
-
- // Creates FrameBuffer from a byte array.
- ::tflite::support::StatusOr<std::unique_ptr<FrameBuffer>>
- -CreateFrameBufferFromBytes(JNIEnv* env, jbyteArray jimage_bytes, jint width,
- - jint height, jint jorientation,
- +CreateFrameBufferFromBytes(JNIEnv* env,
- + jbyteArray jimage_bytes,
- + jint width,
- + jint height,
- + jint jorientation,
- jint jcolor_space_type,
- jlongArray jbyte_array_handle);
-
- // Creates FrameBuffer from YUV planes.
- ::tflite::support::StatusOr<std::unique_ptr<FrameBuffer>>
- -CreateFrameBufferFromYuvPlanes(JNIEnv* env, jobject jy_plane, jobject ju_plane,
- - jobject jv_plane, jint width, jint height,
- - jint row_stride_y, jint row_stride_uv,
- - jint pixel_stride_uv, jint jorientation);
- +CreateFrameBufferFromYuvPlanes(JNIEnv* env,
- + jobject jy_plane,
- + jobject ju_plane,
- + jobject jv_plane,
- + jint width,
- + jint height,
- + jint row_stride_y,
- + jint row_stride_uv,
- + jint pixel_stride_uv,
- + jint jorientation);
- } // namespace vision
- } // namespace task
- } // namespace tflite
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/image_searcher_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/image_searcher_jni.cc
- index e57f12a16aab3..84cad5db43ea2 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/image_searcher_jni.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/image_searcher_jni.cc
- @@ -52,7 +52,8 @@ using ::tflite::task::vision::ImageSearcherOptions;
-
- // Creates an ImageSearcherOptions proto based on the Java class.
- ImageSearcherOptions ConvertToProtoOptions(jlong base_options_handle,
- - bool l2_normalize, bool quantize,
- + bool l2_normalize,
- + bool quantize,
- int index_descriptor,
- int max_results) {
- ImageSearcherOptions proto_options;
- @@ -124,7 +125,9 @@ jobject ConvertToSearchResults(JNIEnv* env, const SearchResult& results) {
-
- extern "C" JNIEXPORT void JNICALL
- Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_deinitJni(
- - JNIEnv* env, jobject thiz, jlong native_handle) {
- + JNIEnv* env,
- + jobject thiz,
- + jlong native_handle) {
- delete reinterpret_cast<ImageSearcher*>(native_handle);
- }
-
- @@ -133,10 +136,16 @@ Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_deinitJni(
- // values will be ignored.
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_initJniWithModelFdAndOptions(
- - JNIEnv* env, jclass thiz, jint model_descriptor,
- - jlong model_descriptor_length, jlong model_descriptor_offset,
- - jlong base_options_handle, bool l2_normalize, bool quantize,
- - jint index_descriptor, int max_results) {
- + JNIEnv* env,
- + jclass thiz,
- + jint model_descriptor,
- + jlong model_descriptor_length,
- + jlong model_descriptor_offset,
- + jlong base_options_handle,
- + bool l2_normalize,
- + bool quantize,
- + jint index_descriptor,
- + int max_results) {
- ImageSearcherOptions proto_options =
- ConvertToProtoOptions(base_options_handle, l2_normalize, quantize,
- index_descriptor, max_results);
- @@ -156,8 +165,14 @@ Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_initJniWithModelFdAn
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_initJniWithByteBuffer(
- - JNIEnv* env, jclass thiz, jobject model_buffer, jlong base_options_handle,
- - bool l2_normalize, bool quantize, jlong index_descriptor, int max_results) {
- + JNIEnv* env,
- + jclass thiz,
- + jobject model_buffer,
- + jlong base_options_handle,
- + bool l2_normalize,
- + bool quantize,
- + jlong index_descriptor,
- + int max_results) {
- ImageSearcherOptions proto_options =
- ConvertToProtoOptions(base_options_handle, l2_normalize, quantize,
- index_descriptor, max_results);
- @@ -170,7 +185,10 @@ Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_initJniWithByteBuffe
-
- extern "C" JNIEXPORT jobject JNICALL
- Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_searchNative(
- - JNIEnv* env, jclass thiz, jlong native_handle, jlong frame_buffer_handle,
- + JNIEnv* env,
- + jclass thiz,
- + jlong native_handle,
- + jlong frame_buffer_handle,
- jintArray jroi) {
- auto* searcher = reinterpret_cast<ImageSearcher*>(native_handle);
- // frame_buffer will be deleted after inference is done in
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc
- index 40fa4472d37e1..8d8c8eec34295 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc
- @@ -194,7 +194,9 @@ jlong CreateImageSegmenterFromOptions(JNIEnv* env,
-
- extern "C" JNIEXPORT void JNICALL
- Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_deinitJni(
- - JNIEnv* env, jobject thiz, jlong native_handle) {
- + JNIEnv* env,
- + jobject thiz,
- + jlong native_handle) {
- delete reinterpret_cast<ImageSegmenter*>(native_handle);
- }
-
- @@ -203,9 +205,14 @@ Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_deinitJni(
- // values will be ignored.
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFdAndOptions(
- - JNIEnv* env, jclass thiz, jint file_descriptor,
- - jlong file_descriptor_length, jlong file_descriptor_offset,
- - jstring display_names_locale, jint output_type, jlong base_options_handle) {
- + JNIEnv* env,
- + jclass thiz,
- + jint file_descriptor,
- + jlong file_descriptor_length,
- + jlong file_descriptor_offset,
- + jstring display_names_locale,
- + jint output_type,
- + jlong base_options_handle) {
- ImageSegmenterOptions proto_options = ConvertToProtoOptions(
- env, display_names_locale, output_type, base_options_handle);
- auto file_descriptor_meta = proto_options.mutable_base_options()
- @@ -223,8 +230,12 @@ Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFd
-
- extern "C" JNIEXPORT jlong JNICALL
- Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuffer(
- - JNIEnv* env, jclass thiz, jobject model_buffer,
- - jstring display_names_locale, jint output_type, jlong base_options_handle) {
- + JNIEnv* env,
- + jclass thiz,
- + jobject model_buffer,
- + jstring display_names_locale,
- + jint output_type,
- + jlong base_options_handle) {
- ImageSegmenterOptions proto_options = ConvertToProtoOptions(
- env, display_names_locale, output_type, base_options_handle);
- proto_options.mutable_base_options()->mutable_model_file()->set_file_content(
- @@ -235,8 +246,13 @@ Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuf
-
- extern "C" JNIEXPORT void JNICALL
- Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_segmentNative(
- - JNIEnv* env, jclass thiz, jlong native_handle, jlong frame_buffer_handle,
- - jobject jmask_buffers, jintArray jmask_shape, jobject jcolored_labels) {
- + JNIEnv* env,
- + jclass thiz,
- + jlong native_handle,
- + jlong frame_buffer_handle,
- + jobject jmask_buffers,
- + jintArray jmask_shape,
- + jobject jcolored_labels) {
- auto* segmenter = reinterpret_cast<ImageSegmenter*>(native_handle);
- // frame_buffer will be deleted after inference is done in
- // base_vision_api_jni.cc.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
- index 65a01c0b9d33a..2a72338741626 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
- @@ -17,13 +17,13 @@ limitations under the License.
-
- #include <string>
-
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/str_format.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/str_format.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- -#include "flatbuffers/flatbuffers.h" // from @flatbuffers
- #include "contrib/minizip/ioapi.h"
- #include "contrib/minizip/unzip.h"
- +#include "flatbuffers/flatbuffers.h" // from @flatbuffers
- #include "tensorflow/lite/schema/schema_generated.h"
- #include "tensorflow_lite_support/cc/common.h"
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- @@ -46,7 +46,8 @@ using ::tflite::support::TfLiteSupportStatus;
- // Util to get item from src_vector specified by index.
- template <typename T>
- const T* GetItemFromVector(
- - const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector, int index) {
- + const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector,
- + int index) {
- if (src_vector == nullptr || index < 0 || index >= src_vector->size()) {
- return nullptr;
- }
- @@ -158,7 +159,8 @@ ModelMetadataExtractor::FindFirstProcessUnit(
- /* static */
- std::string ModelMetadataExtractor::FindFirstAssociatedFileName(
- const tflite::TensorMetadata& tensor_metadata,
- - tflite::AssociatedFileType type, absl::string_view locale) {
- + tflite::AssociatedFileType type,
- + absl::string_view locale) {
- if (tensor_metadata.associated_files() == nullptr) {
- return std::string();
- }
- @@ -175,7 +177,8 @@ std::string ModelMetadataExtractor::FindFirstAssociatedFileName(
- }
-
- absl::Status ModelMetadataExtractor::InitFromModelBuffer(
- - const char* buffer_data, size_t buffer_size) {
- + const char* buffer_data,
- + size_t buffer_size) {
- // Rely on the simplest, base flatbuffers verifier. Here is not the place to
- // e.g. use an OpResolver: we just want to make sure the buffer is valid to
- // access the metadata.
- @@ -234,7 +237,8 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer(
- }
-
- absl::Status ModelMetadataExtractor::ExtractAssociatedFiles(
- - const char* buffer_data, size_t buffer_size) {
- + const char* buffer_data,
- + size_t buffer_size) {
- // Create in-memory read-only zip file.
- ZipReadOnlyMemFile mem_file = ZipReadOnlyMemFile(buffer_data, buffer_size);
- // Open zip.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h
- index c2b28d18ef7d8..007919d581431 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h
- @@ -16,8 +16,8 @@ limitations under the License.
- #define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_
-
- #include "absl/container/flat_hash_map.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/string_view.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow/lite/schema/schema_generated.h"
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc
- index 9d256b3322fb0..299ade3e95d54 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc
- @@ -19,9 +19,9 @@ limitations under the License.
- #include <cstring>
- #include <functional>
-
- -#include "flatbuffers/flatbuffers.h" // from @flatbuffers
- #include "contrib/minizip/ioapi.h"
- #include "contrib/minizip/zip.h"
- +#include "flatbuffers/flatbuffers.h" // from @flatbuffers
- #include "tensorflow/lite/schema/schema_generated.h"
- #include "tensorflow_lite_support/cc/common.h"
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h
- index 510e6c04cdda1..4410f8481f97d 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h
- @@ -17,8 +17,8 @@ limitations under the License.
- #define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_POPULATOR_H_
-
- #include "absl/container/flat_hash_map.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "flatbuffers/flatbuffers.h" // from @flatbuffers
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "flatbuffers/flatbuffers.h" // from @flatbuffers
- #include "tensorflow/lite/schema/schema_generated.h"
- #include "tensorflow_lite_support/cc/port/statusor.h"
- #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
- @@ -79,7 +79,8 @@ class ModelMetadataPopulator {
- // Zips and appends associated files to the provided model buffer. Called
- // internally by `Populate()`.
- tflite::support::StatusOr<std::string> AppendAssociatedFiles(
- - const char* model_buffer_data, size_t model_buffer_size);
- + const char* model_buffer_data,
- + size_t model_buffer_size);
-
- // The unpacked model FlatBuffer.
- tflite::ModelT model_t_;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc
- index 17ffbbc67fbec..78e9a9f1abec1 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc
- @@ -137,7 +137,8 @@ template <typename T>
- void UpdateMinimumVersionForArray(
- const flatbuffers::Vector<flatbuffers::Offset<T>>* array,
- Version* min_version) {
- - if (array == nullptr) return;
- + if (array == nullptr)
- + return;
-
- for (int i = 0; i < array->size(); ++i) {
- UpdateMinimumVersionForTable<T>(array->Get(i), min_version);
- @@ -146,8 +147,10 @@ void UpdateMinimumVersionForArray(
-
- template <>
- void UpdateMinimumVersionForTable<tflite::AssociatedFile>(
- - const tflite::AssociatedFile* table, Version* min_version) {
- - if (table == nullptr) return;
- + const tflite::AssociatedFile* table,
- + Version* min_version) {
- + if (table == nullptr)
- + return;
-
- if (table->type() == AssociatedFileType_VOCABULARY) {
- UpdateMinimumVersion(
- @@ -164,8 +167,10 @@ void UpdateMinimumVersionForTable<tflite::AssociatedFile>(
-
- template <>
- void UpdateMinimumVersionForTable<tflite::ProcessUnit>(
- - const tflite::ProcessUnit* table, Version* min_version) {
- - if (table == nullptr) return;
- + const tflite::ProcessUnit* table,
- + Version* min_version) {
- + if (table == nullptr)
- + return;
-
- tflite::ProcessUnitOptions process_unit_type = table->options_type();
- if (process_unit_type == ProcessUnitOptions_BertTokenizerOptions) {
- @@ -191,7 +196,8 @@ void UpdateMinimumVersionForTable<tflite::ProcessUnit>(
- template <>
- void UpdateMinimumVersionForTable<tflite::Content>(const tflite::Content* table,
- Version* min_version) {
- - if (table == nullptr) return;
- + if (table == nullptr)
- + return;
-
- // Checks the ContenProperties field.
- if (table->content_properties_type() == ContentProperties_AudioProperties) {
- @@ -203,8 +209,10 @@ void UpdateMinimumVersionForTable<tflite::Content>(const tflite::Content* table,
-
- template <>
- void UpdateMinimumVersionForTable<tflite::TensorMetadata>(
- - const tflite::TensorMetadata* table, Version* min_version) {
- - if (table == nullptr) return;
- + const tflite::TensorMetadata* table,
- + Version* min_version) {
- + if (table == nullptr)
- + return;
-
- // Checks the associated_files field.
- UpdateMinimumVersionForArray<tflite::AssociatedFile>(
- @@ -220,8 +228,10 @@ void UpdateMinimumVersionForTable<tflite::TensorMetadata>(
-
- template <>
- void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
- - const tflite::SubGraphMetadata* table, Version* min_version) {
- - if (table == nullptr) return;
- + const tflite::SubGraphMetadata* table,
- + Version* min_version) {
- + if (table == nullptr)
- + return;
-
- // Checks in the input/output metadata arrays.
- UpdateMinimumVersionForArray<tflite::TensorMetadata>(
- @@ -268,7 +278,8 @@ void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
-
- template <>
- void UpdateMinimumVersionForTable<tflite::ModelMetadata>(
- - const tflite::ModelMetadata* table, Version* min_version) {
- + const tflite::ModelMetadata* table,
- + Version* min_version) {
- if (table == nullptr) {
- // Should never happen, because VerifyModelMetadataBuffer has verified it.
- TFLITE_LOG(FATAL) << "The ModelMetadata object is null.";
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.cc
- index 3dac8c24af942..392b6b411fe03 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.cc
- @@ -41,14 +41,17 @@ zlib_filefunc64_def& ZipReadOnlyMemFile::GetFileFunc64Def() {
- }
-
- /* static */
- -voidpf ZipReadOnlyMemFile::OpenFile(voidpf opaque, const void* filename,
- +voidpf ZipReadOnlyMemFile::OpenFile(voidpf opaque,
- + const void* filename,
- int mode) {
- // Result is never used, but needs to be non-null for `zipOpen2` not to fail.
- return opaque;
- }
-
- /* static */
- -uLong ZipReadOnlyMemFile::ReadFile(voidpf opaque, voidpf stream, void* buf,
- +uLong ZipReadOnlyMemFile::ReadFile(voidpf opaque,
- + voidpf stream,
- + void* buf,
- uLong size) {
- auto* mem_file = static_cast<ZipReadOnlyMemFile*>(opaque);
- if (mem_file->offset_ < 0 || mem_file->Size() < mem_file->offset_) {
- @@ -65,8 +68,10 @@ uLong ZipReadOnlyMemFile::ReadFile(voidpf opaque, voidpf stream, void* buf,
- }
-
- /* static */
- -uLong ZipReadOnlyMemFile::WriteFile(voidpf opaque, voidpf stream,
- - const void* buf, uLong size) {
- +uLong ZipReadOnlyMemFile::WriteFile(voidpf opaque,
- + voidpf stream,
- + const void* buf,
- + uLong size) {
- // File is not writable.
- return 0;
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h
- index 13927a7afa698..a1799ff509de5 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h
- @@ -58,7 +58,9 @@ class ZipReadOnlyMemFile {
- // The file function implementations used in the `zlib_filefunc64_def`.
- static voidpf OpenFile(voidpf opaque, const void* filename, int mode);
- static uLong ReadFile(voidpf opaque, voidpf stream, void* buf, uLong size);
- - static uLong WriteFile(voidpf opaque, voidpf stream, const void* buf,
- + static uLong WriteFile(voidpf opaque,
- + voidpf stream,
- + const void* buf,
- uLong size);
- static ZPOS64_T TellFile(voidpf opaque, voidpf stream);
- static long SeekFile // NOLINT
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.cc
- index 5999be028689a..38ad17ad8935c 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.cc
- @@ -40,17 +40,22 @@ zlib_filefunc64_def& ZipWritableMemFile::GetFileFunc64Def() {
- return zlib_filefunc64_def_;
- }
-
- -absl::string_view ZipWritableMemFile::GetFileContent() const { return data_; }
- +absl::string_view ZipWritableMemFile::GetFileContent() const {
- + return data_;
- +}
-
- /* static */
- -voidpf ZipWritableMemFile::OpenFile(voidpf opaque, const void* filename,
- +voidpf ZipWritableMemFile::OpenFile(voidpf opaque,
- + const void* filename,
- int mode) {
- // Result is never used, but needs to be non-null for `zipOpen2` not to fail.
- return opaque;
- }
-
- /* static */
- -uLong ZipWritableMemFile::ReadFile(voidpf opaque, voidpf stream, void* buf,
- +uLong ZipWritableMemFile::ReadFile(voidpf opaque,
- + voidpf stream,
- + void* buf,
- uLong size) {
- auto* mem_file = static_cast<ZipWritableMemFile*>(opaque);
- if (mem_file->offset_ < 0 || mem_file->Size() < mem_file->offset_) {
- @@ -67,8 +72,10 @@ uLong ZipWritableMemFile::ReadFile(voidpf opaque, voidpf stream, void* buf,
- }
-
- /* static */
- -uLong ZipWritableMemFile::WriteFile(voidpf opaque, voidpf stream,
- - const void* buf, uLong size) {
- +uLong ZipWritableMemFile::WriteFile(voidpf opaque,
- + voidpf stream,
- + const void* buf,
- + uLong size) {
- auto* mem_file = static_cast<ZipWritableMemFile*>(opaque);
- if (mem_file->offset_ + size > mem_file->Size()) {
- mem_file->data_.resize(mem_file->offset_ + size);
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h
- index 762dd58f0fb41..30e42fdb72a31 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h
- @@ -59,7 +59,9 @@ class ZipWritableMemFile {
- // The file function implementations used in the `zlib_filefunc64_def`.
- static voidpf OpenFile(voidpf opaque, const void* filename, int mode);
- static uLong ReadFile(voidpf opaque, voidpf stream, void* buf, uLong size);
- - static uLong WriteFile(voidpf opaque, voidpf stream, const void* buf,
- + static uLong WriteFile(voidpf opaque,
- + voidpf stream,
- + const void* buf,
- uLong size);
- static ZPOS64_T TellFile(voidpf opaque, voidpf stream);
- static long SeekFile // NOLINT
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc
- index 6185722504f69..8e00452bea983 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc
- @@ -14,7 +14,7 @@ limitations under the License.
- ==============================================================================*/
-
- #include "flatbuffers/flatbuffers.h" // from @flatbuffers
- -#include "flatbuffers/idl.h" // from @flatbuffers
- +#include "flatbuffers/idl.h" // from @flatbuffers
- #include "pybind11/pybind11.h"
- #include "pybind11/pytypes.h"
- #include "pybind11/stl.h"
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java
- index 6c3d23270f3f0..15bcb45c1a4b1 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java
- @@ -33,84 +33,84 @@ import java.nio.ByteBuffer;
- * synchronized as well.
- */
- final class BoundedInputStream extends InputStream {
- - private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1);
- - private final long end; // The valid data for the stream is between [start, end).
- - private long position;
- - private final SeekableByteChannelCompat channel;
- -
- - /**
- - * Creates a {@link BoundedInputStream} with a {@link SeekableByteChannelCompat}.
- - *
- - * @param channel the {@link SeekableByteChannelCompat} that backs up this {@link
- - * BoundedInputStream}
- - * @param start the starting position of this {@link BoundedInputStream} in the given {@link
- - * SeekableByteChannelCompat}
- - * @param remaining the length of this {@link BoundedInputStream}
- - * @throws IllegalArgumentException if {@code start} or {@code remaining} is negative
- - */
- - BoundedInputStream(SeekableByteChannelCompat channel, long start, long remaining) {
- - checkArgument(
- - remaining >= 0 && start >= 0,
- - String.format("Invalid length of stream at offset=%d, length=%d", start, remaining));
- -
- - end = start + remaining;
- - this.channel = channel;
- - position = start;
- - }
- -
- - @Override
- - public int available() throws IOException {
- - return (int) (Math.min(end, channel.size()) - position);
- - }
- -
- - @Override
- - public int read() throws IOException {
- - if (position >= end) {
- - return -1;
- + private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1);
- + private final long end; // The valid data for the stream is between [start, end).
- + private long position;
- + private final SeekableByteChannelCompat channel;
- +
- + /**
- + * Creates a {@link BoundedInputStream} with a {@link SeekableByteChannelCompat}.
- + *
- + * @param channel the {@link SeekableByteChannelCompat} that backs up this {@link
- + * BoundedInputStream}
- + * @param start the starting position of this {@link BoundedInputStream} in the given {@link
- + * SeekableByteChannelCompat}
- + * @param remaining the length of this {@link BoundedInputStream}
- + * @throws IllegalArgumentException if {@code start} or {@code remaining} is negative
- + */
- + BoundedInputStream(SeekableByteChannelCompat channel, long start, long remaining) {
- + checkArgument(remaining >= 0 && start >= 0,
- + String.format(
- + "Invalid length of stream at offset=%d, length=%d", start, remaining));
- +
- + end = start + remaining;
- + this.channel = channel;
- + position = start;
- }
-
- - singleByteBuffer.rewind();
- - int count = read(position, singleByteBuffer);
- - if (count < 0) {
- - return count;
- + @Override
- + public int available() throws IOException {
- + return (int) (Math.min(end, channel.size()) - position);
- }
-
- - position++;
- - return singleByteBuffer.get() & 0xff;
- - }
- + @Override
- + public int read() throws IOException {
- + if (position >= end) {
- + return -1;
- + }
-
- - @Override
- - public int read(byte[] b, int off, int len) throws IOException {
- - checkNotNull(b);
- - checkElementIndex(off, b.length, "The start offset");
- - checkElementIndex(len, b.length - off + 1, "The maximumn number of bytes to read");
- + singleByteBuffer.rewind();
- + int count = read(position, singleByteBuffer);
- + if (count < 0) {
- + return count;
- + }
-
- - if (len == 0) {
- - return 0;
- + position++;
- + return singleByteBuffer.get() & 0xff;
- }
-
- - if (len > end - position) {
- - if (position >= end) {
- - return -1;
- - }
- - len = (int) (end - position);
- + @Override
- + public int read(byte[] b, int off, int len) throws IOException {
- + checkNotNull(b);
- + checkElementIndex(off, b.length, "The start offset");
- + checkElementIndex(len, b.length - off + 1, "The maximumn number of bytes to read");
- +
- + if (len == 0) {
- + return 0;
- + }
- +
- + if (len > end - position) {
- + if (position >= end) {
- + return -1;
- + }
- + len = (int) (end - position);
- + }
- +
- + ByteBuffer buf = ByteBuffer.wrap(b, off, len);
- + int count = read(position, buf);
- + if (count > 0) {
- + position += count;
- + }
- + return count;
- }
-
- - ByteBuffer buf = ByteBuffer.wrap(b, off, len);
- - int count = read(position, buf);
- - if (count > 0) {
- - position += count;
- + private int read(long position, ByteBuffer buf) throws IOException {
- + int count;
- + synchronized (channel) {
- + channel.position(position);
- + count = channel.read(buf);
- + }
- + buf.flip();
- + return count;
- }
- - return count;
- - }
- -
- - private int read(long position, ByteBuffer buf) throws IOException {
- - int count;
- - synchronized (channel) {
- - channel.position(position);
- - count = channel.read(buf);
- - }
- - buf.flip();
- - return count;
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java
- index e5d54a415edc4..354119b02822e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java
- @@ -15,116 +15,114 @@ limitations under the License.
-
- package org.tensorflow.lite.support.metadata;
-
- -import static java.lang.Math.min;
- import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
- import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
-
- +import static java.lang.Math.min;
- +
- import java.nio.ByteBuffer;
- import java.nio.channels.NonWritableChannelException;
-
- /** Implements the {@link SeekableByteChannelCompat} on top of {@link ByteBuffer}. */
- final class ByteBufferChannel implements SeekableByteChannelCompat {
- + /** The ByteBuffer that holds the data. */
- + private final ByteBuffer buffer;
- +
- + /**
- + * Creates a {@link ByteBufferChannel} that wraps a {@link ByteBuffer}.
- + *
- + * @param buffer the {@link ByteBuffer} that backs this {@link ByteBufferChannel}
- + * @throws NullPointerException if {@code buffer} is null
- + */
- + public ByteBufferChannel(ByteBuffer buffer) {
- + checkNotNull(buffer, "The ByteBuffer cannot be null.");
- + this.buffer = buffer;
- + }
- +
- + @Override
- + public void close() {}
-
- - /** The ByteBuffer that holds the data. */
- - private final ByteBuffer buffer;
- -
- - /**
- - * Creates a {@link ByteBufferChannel} that wraps a {@link ByteBuffer}.
- - *
- - * @param buffer the {@link ByteBuffer} that backs this {@link ByteBufferChannel}
- - * @throws NullPointerException if {@code buffer} is null
- - */
- - public ByteBufferChannel(ByteBuffer buffer) {
- - checkNotNull(buffer, "The ByteBuffer cannot be null.");
- - this.buffer = buffer;
- - }
- -
- - @Override
- - public void close() {}
- -
- - @Override
- - public boolean isOpen() {
- - return true;
- - }
- -
- - @Override
- - public long position() {
- - return buffer.position();
- - }
- -
- - /**
- - * Sets this channel's position.
- - *
- - * @param newPosition the new position, a non-negative integer counting the number of bytes from
- - * the beginning of the entity
- - * @return this channel
- - * @throws IllegalArgumentException if the new position is negative, or greater than the size of
- - * the underlying {@link ByteBuffer}, or greater than Integer.MAX_VALUE
- - */
- - @Override
- - public synchronized ByteBufferChannel position(long newPosition) {
- - checkArgument(
- - (newPosition >= 0 && newPosition <= Integer.MAX_VALUE),
- - "The new position should be non-negative and be less than Integer.MAX_VALUE.");
- - buffer.position((int) newPosition);
- - return this;
- - }
- -
- - /**
- - * {@inheritDoc}
- - *
- - * <p>Bytes are read starting at this channel's current position, and then the position is updated
- - * with the number of bytes actually read. Otherwise this method behaves exactly as specified in
- - * the {@link ReadableByteChannel} interface.
- - */
- - @Override
- - public synchronized int read(ByteBuffer dst) {
- - if (buffer.remaining() == 0) {
- - return -1;
- + @Override
- + public boolean isOpen() {
- + return true;
- }
-
- - int count = min(dst.remaining(), buffer.remaining());
- - if (count > 0) {
- - ByteBuffer tempBuffer = buffer.slice();
- - tempBuffer.order(buffer.order()).limit(count);
- - dst.put(tempBuffer);
- - buffer.position(buffer.position() + count);
- + @Override
- + public long position() {
- + return buffer.position();
- }
- - return count;
- - }
- -
- - @Override
- - public long size() {
- - return buffer.limit();
- - }
- -
- - @Override
- - public synchronized ByteBufferChannel truncate(long size) {
- - checkArgument(
- - (size >= 0 && size <= Integer.MAX_VALUE),
- - "The new size should be non-negative and be less than Integer.MAX_VALUE.");
- -
- - if (size < buffer.limit()) {
- - buffer.limit((int) size);
- - if (buffer.position() > size) {
- - buffer.position((int) size);
- - }
- +
- + /**
- + * Sets this channel's position.
- + *
- + * @param newPosition the new position, a non-negative integer counting the number of bytes from
- + * the beginning of the entity
- + * @return this channel
- + * @throws IllegalArgumentException if the new position is negative, or greater than the size of
- + * the underlying {@link ByteBuffer}, or greater than Integer.MAX_VALUE
- + */
- + @Override
- + public synchronized ByteBufferChannel position(long newPosition) {
- + checkArgument((newPosition >= 0 && newPosition <= Integer.MAX_VALUE),
- + "The new position should be non-negative and be less than Integer.MAX_VALUE.");
- + buffer.position((int) newPosition);
- + return this;
- + }
- +
- + /**
- + * {@inheritDoc}
- + *
- + * <p>Bytes are read starting at this channel's current position, and then the position is
- + * updated with the number of bytes actually read. Otherwise this method behaves exactly as
- + * specified in the {@link ReadableByteChannel} interface.
- + */
- + @Override
- + public synchronized int read(ByteBuffer dst) {
- + if (buffer.remaining() == 0) {
- + return -1;
- + }
- +
- + int count = min(dst.remaining(), buffer.remaining());
- + if (count > 0) {
- + ByteBuffer tempBuffer = buffer.slice();
- + tempBuffer.order(buffer.order()).limit(count);
- + dst.put(tempBuffer);
- + buffer.position(buffer.position() + count);
- + }
- + return count;
- + }
- +
- + @Override
- + public long size() {
- + return buffer.limit();
- }
- - return this;
- - }
-
- - @Override
- - public synchronized int write(ByteBuffer src) {
- - if (buffer.isReadOnly()) {
- - throw new NonWritableChannelException();
- + @Override
- + public synchronized ByteBufferChannel truncate(long size) {
- + checkArgument((size >= 0 && size <= Integer.MAX_VALUE),
- + "The new size should be non-negative and be less than Integer.MAX_VALUE.");
- +
- + if (size < buffer.limit()) {
- + buffer.limit((int) size);
- + if (buffer.position() > size) {
- + buffer.position((int) size);
- + }
- + }
- + return this;
- }
-
- - int count = min(src.remaining(), buffer.remaining());
- - if (count > 0) {
- - ByteBuffer tempBuffer = src.slice();
- - tempBuffer.order(buffer.order()).limit(count);
- - buffer.put(tempBuffer);
- + @Override
- + public synchronized int write(ByteBuffer src) {
- + if (buffer.isReadOnly()) {
- + throw new NonWritableChannelException();
- + }
- +
- + int count = min(src.remaining(), buffer.remaining());
- + if (count > 0) {
- + ByteBuffer tempBuffer = src.slice();
- + tempBuffer.order(buffer.order()).limit(count);
- + buffer.put(tempBuffer);
- + }
- + return count;
- }
- - return count;
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
- index 183d416481156..3fb3c48118748 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
- @@ -17,15 +17,16 @@ package org.tensorflow.lite.support.metadata;
-
- import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
-
- +import org.checkerframework.checker.nullness.qual.Nullable;
- +import org.tensorflow.lite.schema.Tensor;
- +import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
- +import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
- +
- import java.io.IOException;
- import java.io.InputStream;
- import java.nio.ByteBuffer;
- import java.util.Set;
- import java.util.zip.ZipException;
- -import org.checkerframework.checker.nullness.qual.Nullable;
- -import org.tensorflow.lite.schema.Tensor;
- -import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
- -import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
-
- /**
- * Loads metadata from TFLite Model FlatBuffer.
- @@ -53,328 +54,329 @@ import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
- * MetadataExtractor} omits subgraph index as an input in its methods.
- */
- public class MetadataExtractor {
- + /** The helper class to load metadata from TFLite model FlatBuffer. */
- + private final ModelInfo modelInfo;
- +
- + /** The helper class to load metadata from TFLite metadata FlatBuffer. */
- + @Nullable
- + private final ModelMetadataInfo metadataInfo;
- +
- + /** The handler to load associated files through zip. */
- + @Nullable
- + private final ZipFile zipFile;
- +
- + /**
- + * Creates a {@link MetadataExtractor} with TFLite model FlatBuffer.
- + *
- + * @param buffer the TFLite model FlatBuffer
- + * @throws IllegalArgumentException if the number of input or output tensors in the model does
- + * not
- + * match that in the metadata
- + * @throws IOException if an error occurs while reading the model as a Zip file
- + */
- + public MetadataExtractor(ByteBuffer buffer) throws IOException {
- + modelInfo = new ModelInfo(buffer);
- + ByteBuffer metadataBuffer = modelInfo.getMetadataBuffer();
- + if (metadataBuffer != null) {
- + metadataInfo = new ModelMetadataInfo(metadataBuffer);
- +
- + // Prints warning message if the minimum parser version is not satisfied.
- + if (!isMinimumParserVersionSatisfied()) {
- + System.err.printf(
- + "<Warning> Some fields in the metadata belong to a future schema. The minimum parser"
- + + " version required is %s, but the version of the current metadata parser is %s",
- + metadataInfo.getMininumParserVersion(), MetadataParser.VERSION);
- + }
- +
- + checkArgument(modelInfo.getInputTensorCount() == metadataInfo.getInputTensorCount(),
- + String.format(
- + "The number of input tensors in the model is %d. The number of input tensors that"
- + + " recorded in the metadata is %d. These two values does not match.",
- + modelInfo.getInputTensorCount(), metadataInfo.getInputTensorCount()));
- + checkArgument(modelInfo.getOutputTensorCount() == metadataInfo.getOutputTensorCount(),
- + String.format(
- + "The number of output tensors in the model is %d. The number of output tensors that"
- + + " recorded in the metadata is %d. These two values does not match.",
- + modelInfo.getOutputTensorCount(), metadataInfo.getOutputTensorCount()));
- + } else {
- + // It is allowed to pass in a model FlatBuffer without TFLite metadata. However,
- + // invoking methods that read from TFLite metadata will cause runtime errors.
- + metadataInfo = null;
- + }
- +
- + zipFile = createZipFile(buffer);
- + }
-
- - /** The helper class to load metadata from TFLite model FlatBuffer. */
- - private final ModelInfo modelInfo;
- -
- - /** The helper class to load metadata from TFLite metadata FlatBuffer. */
- - @Nullable private final ModelMetadataInfo metadataInfo;
- -
- - /** The handler to load associated files through zip. */
- - @Nullable private final ZipFile zipFile;
- -
- - /**
- - * Creates a {@link MetadataExtractor} with TFLite model FlatBuffer.
- - *
- - * @param buffer the TFLite model FlatBuffer
- - * @throws IllegalArgumentException if the number of input or output tensors in the model does not
- - * match that in the metadata
- - * @throws IOException if an error occurs while reading the model as a Zip file
- - */
- - public MetadataExtractor(ByteBuffer buffer) throws IOException {
- - modelInfo = new ModelInfo(buffer);
- - ByteBuffer metadataBuffer = modelInfo.getMetadataBuffer();
- - if (metadataBuffer != null) {
- - metadataInfo = new ModelMetadataInfo(metadataBuffer);
- -
- - // Prints warning message if the minimum parser version is not satisfied.
- - if (!isMinimumParserVersionSatisfied()) {
- - System.err.printf(
- - "<Warning> Some fields in the metadata belong to a future schema. The minimum parser"
- - + " version required is %s, but the version of the current metadata parser is %s",
- - metadataInfo.getMininumParserVersion(), MetadataParser.VERSION);
- - }
- -
- - checkArgument(
- - modelInfo.getInputTensorCount() == metadataInfo.getInputTensorCount(),
- - String.format(
- - "The number of input tensors in the model is %d. The number of input tensors that"
- - + " recorded in the metadata is %d. These two values does not match.",
- - modelInfo.getInputTensorCount(), metadataInfo.getInputTensorCount()));
- - checkArgument(
- - modelInfo.getOutputTensorCount() == metadataInfo.getOutputTensorCount(),
- - String.format(
- - "The number of output tensors in the model is %d. The number of output tensors that"
- - + " recorded in the metadata is %d. These two values does not match.",
- - modelInfo.getOutputTensorCount(), metadataInfo.getOutputTensorCount()));
- - } else {
- - // It is allowed to pass in a model FlatBuffer without TFLite metadata. However, invoking
- - // methods that read from TFLite metadata will cause runtime errors.
- - metadataInfo = null;
- + /**
- + * Quantization parameters that corresponds to the table, {@code QuantizationParameters}, in the
- + * <a
- + * href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite
- + * Model schema file.</a>
- + *
- + * <p>Since per-channel quantization does not apply to input and output tensors, {@code scale}
- + * and
- + * {@code zero_point} are both single values instead of arrays.
- + *
- + * <p>For tensor that are not quantized, the values of scale and zero_point are both 0.
- + *
- + * <p>Given a quantized value q, the corresponding float value f should be: <br>
- + * f = scale * (q - zero_point) <br>
- + */
- + public static class QuantizationParams {
- + /** The scale value used in quantization. */
- + private final float scale;
- + /** The zero point value used in quantization. */
- + private final int zeroPoint;
- +
- + /**
- + * Creates a {@link QuantizationParams} with {@code scale} and {@code zero_point}.
- + *
- + * @param scale The scale value used in quantization.
- + * @param zeroPoint The zero point value used in quantization.
- + */
- + public QuantizationParams(final float scale, final int zeroPoint) {
- + this.scale = scale;
- + this.zeroPoint = zeroPoint;
- + }
- +
- + /** Returns the scale value. */
- + public float getScale() {
- + return scale;
- + }
- +
- + /** Returns the zero point value. */
- + public int getZeroPoint() {
- + return zeroPoint;
- + }
- }
-
- - zipFile = createZipFile(buffer);
- - }
- -
- - /**
- - * Quantization parameters that corresponds to the table, {@code QuantizationParameters}, in the
- - * <a
- - * href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite
- - * Model schema file.</a>
- - *
- - * <p>Since per-channel quantization does not apply to input and output tensors, {@code scale} and
- - * {@code zero_point} are both single values instead of arrays.
- - *
- - * <p>For tensor that are not quantized, the values of scale and zero_point are both 0.
- - *
- - * <p>Given a quantized value q, the corresponding float value f should be: <br>
- - * f = scale * (q - zero_point) <br>
- - */
- - public static class QuantizationParams {
- - /** The scale value used in quantization. */
- - private final float scale;
- - /** The zero point value used in quantization. */
- - private final int zeroPoint;
- + /** Returns {@code true} if the model has metadata. Otherwise, returns {@code false}. */
- + public boolean hasMetadata() {
- + return metadataInfo != null;
- + }
-
- /**
- - * Creates a {@link QuantizationParams} with {@code scale} and {@code zero_point}.
- + * Gets the packed associated file with the specified {@code fileName}.
- *
- - * @param scale The scale value used in quantization.
- - * @param zeroPoint The zero point value used in quantization.
- + * @param fileName the name of the associated file
- + * @return the raw input stream containing specified file
- + * @throws IllegalStateException if the model is not a zip file
- + * @throws IllegalArgumentException if the specified file does not exist in the model
- */
- - public QuantizationParams(final float scale, final int zeroPoint) {
- - this.scale = scale;
- - this.zeroPoint = zeroPoint;
- + public InputStream getAssociatedFile(String fileName) {
- + assertZipFile();
- + return zipFile.getRawInputStream(fileName);
- }
-
- - /** Returns the scale value. */
- - public float getScale() {
- - return scale;
- + /**
- + * Gets the file names of the associated files.
- + *
- + * @return the file names of the associated files
- + * @throws IllegalStateException if the model is not a zip file
- + */
- + public Set<String> getAssociatedFileNames() {
- + assertZipFile();
- + return zipFile.getFileNames();
- }
-
- - /** Returns the zero point value. */
- - public int getZeroPoint() {
- - return zeroPoint;
- + /** Gets the count of input tensors in the model. */
- + public int getInputTensorCount() {
- + return modelInfo.getInputTensorCount();
- }
- - }
- -
- - /** Returns {@code true} if the model has metadata. Otherwise, returns {@code false}. */
- - public boolean hasMetadata() {
- - return metadataInfo != null;
- - }
- -
- - /**
- - * Gets the packed associated file with the specified {@code fileName}.
- - *
- - * @param fileName the name of the associated file
- - * @return the raw input stream containing specified file
- - * @throws IllegalStateException if the model is not a zip file
- - * @throws IllegalArgumentException if the specified file does not exist in the model
- - */
- - public InputStream getAssociatedFile(String fileName) {
- - assertZipFile();
- - return zipFile.getRawInputStream(fileName);
- - }
- -
- - /**
- - * Gets the file names of the associated files.
- - *
- - * @return the file names of the associated files
- - * @throws IllegalStateException if the model is not a zip file
- - */
- - public Set<String> getAssociatedFileNames() {
- - assertZipFile();
- - return zipFile.getFileNames();
- - }
- -
- - /** Gets the count of input tensors in the model. */
- - public int getInputTensorCount() {
- - return modelInfo.getInputTensorCount();
- - }
- -
- - /**
- - * Gets the metadata for the input tensor specified by {@code inputIndex}.
- - *
- - * @param inputIndex the index of the desired input tensor
- - * @throws IllegalStateException if this model does not contain model metadata
- - */
- - @Nullable
- - public TensorMetadata getInputTensorMetadata(int inputIndex) {
- - assertMetadataInfo();
- - return metadataInfo.getInputTensorMetadata(inputIndex);
- - }
- -
- - /**
- - * Gets the quantization parameters for the input tensor specified by {@code inputIndex}.
- - *
- - * @param inputIndex the index of the desired input tensor
- - */
- - public QuantizationParams getInputTensorQuantizationParams(int inputIndex) {
- - Tensor tensor = modelInfo.getInputTensor(inputIndex);
- - return modelInfo.getQuantizationParams(tensor);
- - }
- -
- - /**
- - * Gets the shape of the input tensor with {@code inputIndex}.
- - *
- - * @param inputIndex the index of the desired input tensor
- - */
- - public int[] getInputTensorShape(int inputIndex) {
- - return modelInfo.getInputTensorShape(inputIndex);
- - }
- -
- - /**
- - * Gets the {@link TensorType} of the input tensor with {@code inputIndex}.
- - *
- - * @param inputIndex the index of the desired input tensor
- - */
- - public byte getInputTensorType(int inputIndex) {
- - return modelInfo.getInputTensorType(inputIndex);
- - }
- -
- - /**
- - * Gets the root handler for the model metadata.
- - *
- - * @throws IllegalStateException if this model does not contain model metadata
- - */
- - public ModelMetadata getModelMetadata() {
- - assertMetadataInfo();
- - return metadataInfo.getModelMetadata();
- - }
- -
- - /** Gets the count of output tensors in the model. */
- - public int getOutputTensorCount() {
- - return modelInfo.getOutputTensorCount();
- - }
- -
- - /**
- - * Gets the metadata for the output tensor specified by {@code outputIndex}.
- - *
- - * @param outputIndex the index of the desired output tensor
- - * @throws IllegalStateException if this model does not contain model metadata
- - */
- - @Nullable
- - public TensorMetadata getOutputTensorMetadata(int outputIndex) {
- - assertMetadataInfo();
- - return metadataInfo.getOutputTensorMetadata(outputIndex);
- - }
- -
- - /**
- - * Gets the quantization parameters for the output tensor specified by {@code outputIndex}.
- - *
- - * @param outputIndex the index of the desired output tensor
- - */
- - public QuantizationParams getOutputTensorQuantizationParams(int outputIndex) {
- - Tensor tensor = modelInfo.getOutputTensor(outputIndex);
- - return modelInfo.getQuantizationParams(tensor);
- - }
- -
- - /**
- - * Gets the shape of the output tensor with {@code outputIndex}.
- - *
- - * @param outputIndex the index of the desired output tensor
- - */
- - public int[] getOutputTensorShape(int outputIndex) {
- - return modelInfo.getOutputTensorShape(outputIndex);
- - }
- -
- - /**
- - * Gets the {@link TensorType} of the output tensor with {@code outputIndex}.
- - *
- - * @param outputIndex the index of the desired output tensor
- - */
- - public byte getOutputTensorType(int outputIndex) {
- - return modelInfo.getOutputTensorType(outputIndex);
- - }
- -
- - /**
- - * Returns {@code true} if the minimum parser version required by the given metadata flatbuffer
- - * precedes or equals to the version of the metadata parser that this MetadataExtractor library is
- - * relying on. All fields in the metadata can be parsed correctly with this metadata extractor
- - * library in this case. Otherwise, it returns {@code false}.
- - *
- - * <p>For example, assume the underlying metadata parser version is {@code 1.14.1},
- - *
- - * <ul>
- - * <li>it returns {@code true}, if the required minimum parser version is the same or older,
- - * such as {@code 1.14.1} or {@code 1.14.0}. Null version precedes all numeric versions,
- - * because some metadata flatbuffers are generated before the first versioned release; <br>
- - * <li>it returns {@code false}, if the required minimum parser version is newer, such as {@code
- - * 1.14.2}.
- - * </ul>
- - */
- - public final boolean isMinimumParserVersionSatisfied() {
- - String minVersion = metadataInfo.getMininumParserVersion();
- - if (minVersion == null) {
- - return true;
- +
- + /**
- + * Gets the metadata for the input tensor specified by {@code inputIndex}.
- + *
- + * @param inputIndex the index of the desired input tensor
- + * @throws IllegalStateException if this model does not contain model metadata
- + */
- + @Nullable
- + public TensorMetadata getInputTensorMetadata(int inputIndex) {
- + assertMetadataInfo();
- + return metadataInfo.getInputTensorMetadata(inputIndex);
- }
- - return compareVersions(minVersion, MetadataParser.VERSION) <= 0;
- - }
- -
- - /**
- - * Asserts if {@link #metadataInfo} is not initialized. Some models may not have metadata and this
- - * is allowed. However, invoking methods that reads the metadata is not allowed.
- - *
- - * @throws IllegalStateException if this model does not contain model metadata
- - */
- - private void assertMetadataInfo() {
- - if (metadataInfo == null) {
- - throw new IllegalStateException("This model does not contain model metadata.");
- +
- + /**
- + * Gets the quantization parameters for the input tensor specified by {@code inputIndex}.
- + *
- + * @param inputIndex the index of the desired input tensor
- + */
- + public QuantizationParams getInputTensorQuantizationParams(int inputIndex) {
- + Tensor tensor = modelInfo.getInputTensor(inputIndex);
- + return modelInfo.getQuantizationParams(tensor);
- }
- - }
- -
- - /**
- - * Asserts if {@link #zipFile} is not initialized. Some models may not have associated files, thus
- - * are not Zip files. This is allowed. However, invoking methods that reads those associated files
- - * is not allowed.
- - *
- - * @throws IllegalStateException if this model is not a Zip file
- - */
- - private void assertZipFile() {
- - if (zipFile == null) {
- - throw new IllegalStateException(
- - "This model does not contain associated files, and is not a Zip file.");
- +
- + /**
- + * Gets the shape of the input tensor with {@code inputIndex}.
- + *
- + * @param inputIndex the index of the desired input tensor
- + */
- + public int[] getInputTensorShape(int inputIndex) {
- + return modelInfo.getInputTensorShape(inputIndex);
- }
- - }
- -
- - /**
- - * Creates a Zip file handler to read the associated files. If the model is not a zip file, i.e.
- - * it does not have associated files, return a null handler.
- - *
- - * @param buffer the TFLite model FlatBuffer
- - * @throws IOException if an error occurs while reading the model as a Zip file
- - */
- - @Nullable
- - private static ZipFile createZipFile(ByteBuffer buffer) throws IOException {
- - try {
- - // Creates the handler to hold the associated files through the Zip.
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(buffer);
- - return ZipFile.createFrom(byteBufferChannel);
- - } catch (ZipException e) {
- - // Some models may not have associate files. Therefore, Those models are not zip files.
- - // However, invoking methods that read associated files later will lead into errors.
- - return null;
- +
- + /**
- + * Gets the {@link TensorType} of the input tensor with {@code inputIndex}.
- + *
- + * @param inputIndex the index of the desired input tensor
- + */
- + public byte getInputTensorType(int inputIndex) {
- + return modelInfo.getInputTensorType(inputIndex);
- }
- - }
- -
- - /**
- - * Compares two semantic version numbers.
- - *
- - * <p>Examples of comparing two versions: <br>
- - * {@code 1.9} precedes {@code 1.14}; <br>
- - * {@code 1.14} precedes {@code 1.14.1}; <br>
- - * {@code 1.14} and {@code 1.14.0} are euqal;
- - *
- - * @return the value {@code 0} if the two versions are equal; a value less than {@code 0} if
- - * {@code version1} precedes {@code version2}; a value greater than {@code 0} if {@code
- - * version2} precedes {@code version1}.
- - */
- - private static int compareVersions(String version1, String version2) {
- - // Using String.split instead of the recommanded Guava Splitter because we've been avoiding
- - // depending on other third party libraries in this project.
- - String[] levels1 = version1.split("\\.", 0);
- - String[] levels2 = version2.split("\\.", 0);
- -
- - int length = Math.max(levels1.length, levels2.length);
- - for (int i = 0; i < length; i++) {
- - Integer v1 = i < levels1.length ? Integer.parseInt(levels1[i]) : 0;
- - Integer v2 = i < levels2.length ? Integer.parseInt(levels2[i]) : 0;
- - int compare = v1.compareTo(v2);
- - if (compare != 0) {
- - return compare;
- - }
- +
- + /**
- + * Gets the root handler for the model metadata.
- + *
- + * @throws IllegalStateException if this model does not contain model metadata
- + */
- + public ModelMetadata getModelMetadata() {
- + assertMetadataInfo();
- + return metadataInfo.getModelMetadata();
- + }
- +
- + /** Gets the count of output tensors in the model. */
- + public int getOutputTensorCount() {
- + return modelInfo.getOutputTensorCount();
- }
-
- - return 0;
- - }
- + /**
- + * Gets the metadata for the output tensor specified by {@code outputIndex}.
- + *
- + * @param outputIndex the index of the desired output tensor
- + * @throws IllegalStateException if this model does not contain model metadata
- + */
- + @Nullable
- + public TensorMetadata getOutputTensorMetadata(int outputIndex) {
- + assertMetadataInfo();
- + return metadataInfo.getOutputTensorMetadata(outputIndex);
- + }
- +
- + /**
- + * Gets the quantization parameters for the output tensor specified by {@code outputIndex}.
- + *
- + * @param outputIndex the index of the desired output tensor
- + */
- + public QuantizationParams getOutputTensorQuantizationParams(int outputIndex) {
- + Tensor tensor = modelInfo.getOutputTensor(outputIndex);
- + return modelInfo.getQuantizationParams(tensor);
- + }
- +
- + /**
- + * Gets the shape of the output tensor with {@code outputIndex}.
- + *
- + * @param outputIndex the index of the desired output tensor
- + */
- + public int[] getOutputTensorShape(int outputIndex) {
- + return modelInfo.getOutputTensorShape(outputIndex);
- + }
- +
- + /**
- + * Gets the {@link TensorType} of the output tensor with {@code outputIndex}.
- + *
- + * @param outputIndex the index of the desired output tensor
- + */
- + public byte getOutputTensorType(int outputIndex) {
- + return modelInfo.getOutputTensorType(outputIndex);
- + }
- +
- + /**
- + * Returns {@code true} if the minimum parser version required by the given metadata flatbuffer
- + * precedes or equals to the version of the metadata parser that this MetadataExtractor library
- + * is relying on. All fields in the metadata can be parsed correctly with this metadata
- + * extractor library in this case. Otherwise, it returns {@code false}.
- + *
- + * <p>For example, assume the underlying metadata parser version is {@code 1.14.1},
- + *
- + * <ul>
- + * <li>it returns {@code true}, if the required minimum parser version is the same or older,
- + * such as {@code 1.14.1} or {@code 1.14.0}. Null version precedes all numeric versions,
- + * because some metadata flatbuffers are generated before the first versioned release;
- + * <br> <li>it returns {@code false}, if the required minimum parser version is newer, such as
- + * {@code 1.14.2}.
- + * </ul>
- + */
- + public final boolean isMinimumParserVersionSatisfied() {
- + String minVersion = metadataInfo.getMininumParserVersion();
- + if (minVersion == null) {
- + return true;
- + }
- + return compareVersions(minVersion, MetadataParser.VERSION) <= 0;
- + }
- +
- + /**
- + * Asserts if {@link #metadataInfo} is not initialized. Some models may not have metadata and
- + * this is allowed. However, invoking methods that reads the metadata is not allowed.
- + *
- + * @throws IllegalStateException if this model does not contain model metadata
- + */
- + private void assertMetadataInfo() {
- + if (metadataInfo == null) {
- + throw new IllegalStateException("This model does not contain model metadata.");
- + }
- + }
- +
- + /**
- + * Asserts if {@link #zipFile} is not initialized. Some models may not have associated files,
- + * thus are not Zip files. This is allowed. However, invoking methods that reads those
- + * associated files is not allowed.
- + *
- + * @throws IllegalStateException if this model is not a Zip file
- + */
- + private void assertZipFile() {
- + if (zipFile == null) {
- + throw new IllegalStateException(
- + "This model does not contain associated files, and is not a Zip file.");
- + }
- + }
- +
- + /**
- + * Creates a Zip file handler to read the associated files. If the model is not a zip file, i.e.
- + * it does not have associated files, return a null handler.
- + *
- + * @param buffer the TFLite model FlatBuffer
- + * @throws IOException if an error occurs while reading the model as a Zip file
- + */
- + @Nullable
- + private static ZipFile createZipFile(ByteBuffer buffer) throws IOException {
- + try {
- + // Creates the handler to hold the associated files through the Zip.
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(buffer);
- + return ZipFile.createFrom(byteBufferChannel);
- + } catch (ZipException e) {
- + // Some models may not have associate files. Therefore, Those models are not zip files.
- + // However, invoking methods that read associated files later will lead into errors.
- + return null;
- + }
- + }
- +
- + /**
- + * Compares two semantic version numbers.
- + *
- + * <p>Examples of comparing two versions: <br>
- + * {@code 1.9} precedes {@code 1.14}; <br>
- + * {@code 1.14} precedes {@code 1.14.1}; <br>
- + * {@code 1.14} and {@code 1.14.0} are euqal;
- + *
- + * @return the value {@code 0} if the two versions are equal; a value less than {@code 0} if
- + * {@code version1} precedes {@code version2}; a value greater than {@code 0} if {@code
- + * version2} precedes {@code version1}.
- + */
- + private static int compareVersions(String version1, String version2) {
- + // Using String.split instead of the recommanded Guava Splitter because we've been avoiding
- + // depending on other third party libraries in this project.
- + String[] levels1 = version1.split("\\.", 0);
- + String[] levels2 = version2.split("\\.", 0);
- +
- + int length = Math.max(levels1.length, levels2.length);
- + for (int i = 0; i < length; i++) {
- + Integer v1 = i < levels1.length ? Integer.parseInt(levels1[i]) : 0;
- + Integer v2 = i < levels2.length ? Integer.parseInt(levels2[i]) : 0;
- + int compare = v1.compareTo(v2);
- + if (compare != 0) {
- + return compare;
- + }
- + }
- +
- + return 0;
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java
- index 8a262a02eab14..1dbf9ebb46386 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java
- @@ -17,11 +17,11 @@ package org.tensorflow.lite.support.metadata;
-
- /** Information about the metadata parser that this metadata extractor library is depending on. */
- public final class MetadataParser {
- - /**
- - * The version of the metadata parser that this metadata extractor library is depending on. The
- - * value should match the value of "Schema Semantic version" in metadata_schema.fbs.
- - */
- - public static final String VERSION = "1.4.0";
- + /**
- + * The version of the metadata parser that this metadata extractor library is depending on. The
- + * value should match the value of "Schema Semantic version" in metadata_schema.fbs.
- + */
- + public static final String VERSION = "1.4.0";
-
- - private MetadataParser() {}
- + private MetadataParser() {}
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java
- index 309a3dbe77470..863ab83e306fb 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java
- @@ -18,10 +18,6 @@ package org.tensorflow.lite.support.metadata;
- import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
- import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
-
- -import java.nio.ByteBuffer;
- -import java.util.ArrayList;
- -import java.util.Collections;
- -import java.util.List;
- import org.checkerframework.checker.nullness.qual.Nullable;
- import org.tensorflow.lite.schema.Buffer;
- import org.tensorflow.lite.schema.Metadata;
- @@ -32,235 +28,237 @@ import org.tensorflow.lite.schema.Tensor;
- import org.tensorflow.lite.schema.TensorType;
- import org.tensorflow.lite.support.metadata.MetadataExtractor.QuantizationParams;
-
- +import java.nio.ByteBuffer;
- +import java.util.ArrayList;
- +import java.util.Collections;
- +import java.util.List;
- +
- /** Extracts model information out of TFLite model FLatBuffer. */
- final class ModelInfo {
- - /** The model that is loaded from TFLite model FlatBuffer. */
- - private final Model model;
- -
- - /** A list of input tensors. */
- - private final List</* @Nullable */ Tensor> inputTensors;
- -
- - /** A list of output tensors. */
- - private final List</* @Nullable */ Tensor> outputTensors;
- -
- - /** Identifier of the TFLite model metadata in the Metadata array. */
- - static final String METADATA_FIELD_NAME = "TFLITE_METADATA";
- -
- - /**
- - * Creates a {@link ModelInfo} with the model FlatBuffer, {@code buffer}.
- - *
- - * <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only supports
- - * single subgraph so far. See the <a
- - * href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
- - * of how to specify subgraph during convertion for more information.</a> Therefore, all methods
- - * in {@link ModelInfo} retrieves metadata of the first subgrpah as default.
- - *
- - * @param buffer the TFLite model FlatBuffer
- - * @throws NullPointerException if {@code buffer} is null
- - * @throws IllegalArgumentException if the model does not contain any subgraph, or the model does
- - * not contain the expected identifier
- - */
- - ModelInfo(ByteBuffer buffer) {
- - assertTFLiteModel(buffer);
- -
- - model = Model.getRootAsModel(buffer);
- - checkArgument(model.subgraphsLength() > 0, "The model does not contain any subgraph.");
- -
- - inputTensors = getInputTensors(model);
- - outputTensors = getOutputTensors(model);
- - }
- -
- - /**
- - * Gets the input tensor with {@code inputIndex}.
- - *
- - * @param inputIndex The index of the desired input tensor.
- - * @throws IllegalArgumentException if the inputIndex specified is invalid.
- - */
- - @Nullable
- - Tensor getInputTensor(int inputIndex) {
- - checkArgument(
- - inputIndex >= 0 && inputIndex < inputTensors.size(),
- - "The inputIndex specified is invalid.");
- - return inputTensors.get(inputIndex);
- - }
- -
- - int getInputTensorCount() {
- - return inputTensors.size();
- - }
- -
- - /**
- - * Gets shape of the input tensor with {@code inputIndex}.
- - *
- - * @param inputIndex The index of the desired intput tensor.
- - */
- - int[] getInputTensorShape(int inputIndex) {
- - Tensor tensor = getInputTensor(inputIndex);
- - return getShape(tensor);
- - }
- -
- - /**
- - * Gets the {@link TensorType} in byte of the input tensor with {@code inputIndex}.
- - *
- - * @param inputIndex The index of the desired intput tensor.
- - */
- - byte getInputTensorType(int inputIndex) {
- - return getInputTensor(inputIndex).type();
- - }
- -
- - /** Gets the metadata FlatBuffer from the model FlatBuffer. */
- - @Nullable
- - ByteBuffer getMetadataBuffer() {
- - // Some models may not have metadata, and this is allowed.
- - if (model.metadataLength() == 0) {
- - return null;
- + /** The model that is loaded from TFLite model FlatBuffer. */
- + private final Model model;
- +
- + /** A list of input tensors. */
- + private final List</* @Nullable */ Tensor> inputTensors;
- +
- + /** A list of output tensors. */
- + private final List</* @Nullable */ Tensor> outputTensors;
- +
- + /** Identifier of the TFLite model metadata in the Metadata array. */
- + static final String METADATA_FIELD_NAME = "TFLITE_METADATA";
- +
- + /**
- + * Creates a {@link ModelInfo} with the model FlatBuffer, {@code buffer}.
- + *
- + * <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only
- + * supports single subgraph so far. See the <a
- + * href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
- + * of how to specify subgraph during convertion for more information.</a> Therefore, all methods
- + * in {@link ModelInfo} retrieves metadata of the first subgrpah as default.
- + *
- + * @param buffer the TFLite model FlatBuffer
- + * @throws NullPointerException if {@code buffer} is null
- + * @throws IllegalArgumentException if the model does not contain any subgraph, or the model
- + * does
- + * not contain the expected identifier
- + */
- + ModelInfo(ByteBuffer buffer) {
- + assertTFLiteModel(buffer);
- +
- + model = Model.getRootAsModel(buffer);
- + checkArgument(model.subgraphsLength() > 0, "The model does not contain any subgraph.");
- +
- + inputTensors = getInputTensors(model);
- + outputTensors = getOutputTensors(model);
- + }
- +
- + /**
- + * Gets the input tensor with {@code inputIndex}.
- + *
- + * @param inputIndex The index of the desired input tensor.
- + * @throws IllegalArgumentException if the inputIndex specified is invalid.
- + */
- + @Nullable
- + Tensor getInputTensor(int inputIndex) {
- + checkArgument(inputIndex >= 0 && inputIndex < inputTensors.size(),
- + "The inputIndex specified is invalid.");
- + return inputTensors.get(inputIndex);
- + }
- +
- + int getInputTensorCount() {
- + return inputTensors.size();
- + }
- +
- + /**
- + * Gets shape of the input tensor with {@code inputIndex}.
- + *
- + * @param inputIndex The index of the desired intput tensor.
- + */
- + int[] getInputTensorShape(int inputIndex) {
- + Tensor tensor = getInputTensor(inputIndex);
- + return getShape(tensor);
- }
-
- - for (int i = 0; i < model.metadataLength(); i++) {
- - Metadata meta = model.metadata(i);
- - if (METADATA_FIELD_NAME.equals(meta.name())) {
- - long bufferIndex = meta.buffer();
- - Buffer metadataBuf = model.buffers((int) bufferIndex);
- - return metadataBuf.dataAsByteBuffer();
- - }
- + /**
- + * Gets the {@link TensorType} in byte of the input tensor with {@code inputIndex}.
- + *
- + * @param inputIndex The index of the desired intput tensor.
- + */
- + byte getInputTensorType(int inputIndex) {
- + return getInputTensor(inputIndex).type();
- }
- - return null;
- - }
- -
- - /**
- - * Gets the output tensor with {@code outputIndex}.
- - *
- - * @param outputIndex The index of the desired outtput tensor.
- - * @throws IllegalArgumentException if the outputIndex specified is invalid.
- - */
- - @Nullable
- - Tensor getOutputTensor(int outputIndex) {
- - checkArgument(
- - outputIndex >= 0 && outputIndex < outputTensors.size(),
- - "The outputIndex specified is invalid.");
- - return outputTensors.get(outputIndex);
- - }
- -
- - int getOutputTensorCount() {
- - return outputTensors.size();
- - }
- -
- - /**
- - * Gets shape of the output tensor with {@code outputIndex}.
- - *
- - * @param outputIndex The index of the desired outtput tensor.
- - */
- - int[] getOutputTensorShape(int outputIndex) {
- - Tensor tensor = getOutputTensor(outputIndex);
- - return getShape(tensor);
- - }
- -
- - /**
- - * Gets the {@link TensorType} in byte of the output tensor {@code outputIndex}.
- - *
- - * @param outputIndex The index of the desired outtput tensor.
- - */
- - byte getOutputTensorType(int outputIndex) {
- - return getOutputTensor(outputIndex).type();
- - }
- -
- - /**
- - * Gets the quantization parameters of a tensor.
- - *
- - * <p>Only quantized tensors have valid {@code QuantizationParameters}. For tensor that are not
- - * quantized, the values of scale and zero_point are both 0.
- - *
- - * @param tensor The tensor whoes quantization parameters is desired.
- - * @throws NullPointerException if the tensor is null.
- - * @throws IllegalArgumentException if {@code scale} and {@code zeroPoint} of the tensor's {@link
- - * QuantizationParameters} are not single values.
- - */
- - QuantizationParams getQuantizationParams(Tensor tensor) {
- - checkNotNull(tensor, "Tensor cannot be null.");
- -
- - float scale;
- - int zeroPoint;
- - QuantizationParameters quantization = tensor.quantization();
- -
- - // Tensors that are not quantized do not have quantization parameters, which can be null when
- - // being extracted from the flatbuffer.
- - if (quantization == null) {
- - scale = 0.0f;
- - zeroPoint = 0;
- - return new QuantizationParams(scale, zeroPoint);
- +
- + /** Gets the metadata FlatBuffer from the model FlatBuffer. */
- + @Nullable
- + ByteBuffer getMetadataBuffer() {
- + // Some models may not have metadata, and this is allowed.
- + if (model.metadataLength() == 0) {
- + return null;
- + }
- +
- + for (int i = 0; i < model.metadataLength(); i++) {
- + Metadata meta = model.metadata(i);
- + if (METADATA_FIELD_NAME.equals(meta.name())) {
- + long bufferIndex = meta.buffer();
- + Buffer metadataBuf = model.buffers((int) bufferIndex);
- + return metadataBuf.dataAsByteBuffer();
- + }
- + }
- + return null;
- + }
- +
- + /**
- + * Gets the output tensor with {@code outputIndex}.
- + *
- + * @param outputIndex The index of the desired outtput tensor.
- + * @throws IllegalArgumentException if the outputIndex specified is invalid.
- + */
- + @Nullable
- + Tensor getOutputTensor(int outputIndex) {
- + checkArgument(outputIndex >= 0 && outputIndex < outputTensors.size(),
- + "The outputIndex specified is invalid.");
- + return outputTensors.get(outputIndex);
- + }
- +
- + int getOutputTensorCount() {
- + return outputTensors.size();
- + }
- +
- + /**
- + * Gets shape of the output tensor with {@code outputIndex}.
- + *
- + * @param outputIndex The index of the desired outtput tensor.
- + */
- + int[] getOutputTensorShape(int outputIndex) {
- + Tensor tensor = getOutputTensor(outputIndex);
- + return getShape(tensor);
- }
-
- - // Tensors that are not quantized do not have quantization parameters.
- - // quantization.scaleLength() and quantization.zeroPointLength() may both return 0.
- - checkArgument(
- - quantization.scaleLength() <= 1,
- - "Input and output tensors do not support per-channel quantization.");
- - checkArgument(
- - quantization.zeroPointLength() <= 1,
- - "Input and output tensors do not support per-channel quantization.");
- -
- - // For tensors that are not quantized, quantization.scale(0) and quantization.zeroPoint(0) will
- - // both be the default value in flatbuffer, 0. This behavior is consistent with the TFlite C++
- - // runtime.
- - scale = quantization.scale(0);
- - // zeroPoint is a long value in the schema, but an integer in the C++ runtime. Here we keep it
- - // consistent with the C++ runtime.
- - zeroPoint = (int) quantization.zeroPoint(0);
- -
- - return new QuantizationParams(scale, zeroPoint);
- - }
- -
- - /**
- - * Verifies if the buffer is a valid TFLite model.
- - *
- - * @param buffer the TFLite model flatbuffer
- - * @throws NullPointerException if {@code buffer} is null.
- - * @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier
- - */
- - private static void assertTFLiteModel(ByteBuffer buffer) {
- - checkNotNull(buffer, "Model flatbuffer cannot be null.");
- - checkArgument(
- - Model.ModelBufferHasIdentifier(buffer),
- - "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
- - + " flatbuffer.");
- - }
- -
- - /**
- - * Gets the shape of a tensor.
- - *
- - * @param tensor The tensor whoes shape is desired.
- - * @throws NullPointerException if the tensor is null.
- - */
- - private static int[] getShape(Tensor tensor) {
- - checkNotNull(tensor, "Tensor cannot be null.");
- - int shapeDim = tensor.shapeLength();
- - int[] tensorShape = new int[shapeDim];
- - for (int i = 0; i < shapeDim; i++) {
- - tensorShape[i] = tensor.shape(i);
- + /**
- + * Gets the {@link TensorType} in byte of the output tensor {@code outputIndex}.
- + *
- + * @param outputIndex The index of the desired outtput tensor.
- + */
- + byte getOutputTensorType(int outputIndex) {
- + return getOutputTensor(outputIndex).type();
- }
- - return tensorShape;
- - }
- -
- - /** Gets input tensors from a model. */
- - private static List<Tensor> getInputTensors(Model model) {
- - // TFLite only support one subgraph currently.
- - SubGraph subgraph = model.subgraphs(0);
- - int tensorNum = subgraph.inputsLength();
- - ArrayList<Tensor> inputTensors = new ArrayList<>(tensorNum);
- - for (int i = 0; i < tensorNum; i++) {
- - inputTensors.add(subgraph.tensors(subgraph.inputs(i)));
- +
- + /**
- + * Gets the quantization parameters of a tensor.
- + *
- + * <p>Only quantized tensors have valid {@code QuantizationParameters}. For tensor that are not
- + * quantized, the values of scale and zero_point are both 0.
- + *
- + * @param tensor The tensor whoes quantization parameters is desired.
- + * @throws NullPointerException if the tensor is null.
- + * @throws IllegalArgumentException if {@code scale} and {@code zeroPoint} of the tensor's
- + * {@link
- + * QuantizationParameters} are not single values.
- + */
- + QuantizationParams getQuantizationParams(Tensor tensor) {
- + checkNotNull(tensor, "Tensor cannot be null.");
- +
- + float scale;
- + int zeroPoint;
- + QuantizationParameters quantization = tensor.quantization();
- +
- + // Tensors that are not quantized do not have quantization parameters, which can be null
- + // when being extracted from the flatbuffer.
- + if (quantization == null) {
- + scale = 0.0f;
- + zeroPoint = 0;
- + return new QuantizationParams(scale, zeroPoint);
- + }
- +
- + // Tensors that are not quantized do not have quantization parameters.
- + // quantization.scaleLength() and quantization.zeroPointLength() may both return 0.
- + checkArgument(quantization.scaleLength() <= 1,
- + "Input and output tensors do not support per-channel quantization.");
- + checkArgument(quantization.zeroPointLength() <= 1,
- + "Input and output tensors do not support per-channel quantization.");
- +
- + // For tensors that are not quantized, quantization.scale(0) and quantization.zeroPoint(0)
- + // will both be the default value in flatbuffer, 0. This behavior is consistent with the
- + // TFlite C++ runtime.
- + scale = quantization.scale(0);
- + // zeroPoint is a long value in the schema, but an integer in the C++ runtime. Here we keep
- + // it consistent with the C++ runtime.
- + zeroPoint = (int) quantization.zeroPoint(0);
- +
- + return new QuantizationParams(scale, zeroPoint);
- }
- - return Collections.unmodifiableList(inputTensors);
- - }
- -
- - /** Gets output tensors from a model. */
- - private static List<Tensor> getOutputTensors(Model model) {
- - // TFLite only support one subgraph currently.
- - SubGraph subgraph = model.subgraphs(0);
- - int tensorNum = subgraph.outputsLength();
- - ArrayList<Tensor> outputTensors = new ArrayList<>(tensorNum);
- - for (int i = 0; i < tensorNum; i++) {
- - outputTensors.add(subgraph.tensors(subgraph.outputs(i)));
- +
- + /**
- + * Verifies if the buffer is a valid TFLite model.
- + *
- + * @param buffer the TFLite model flatbuffer
- + * @throws NullPointerException if {@code buffer} is null.
- + * @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier
- + */
- + private static void assertTFLiteModel(ByteBuffer buffer) {
- + checkNotNull(buffer, "Model flatbuffer cannot be null.");
- + checkArgument(Model.ModelBufferHasIdentifier(buffer),
- + "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
- + + " flatbuffer.");
- + }
- +
- + /**
- + * Gets the shape of a tensor.
- + *
- + * @param tensor The tensor whoes shape is desired.
- + * @throws NullPointerException if the tensor is null.
- + */
- + private static int[] getShape(Tensor tensor) {
- + checkNotNull(tensor, "Tensor cannot be null.");
- + int shapeDim = tensor.shapeLength();
- + int[] tensorShape = new int[shapeDim];
- + for (int i = 0; i < shapeDim; i++) {
- + tensorShape[i] = tensor.shape(i);
- + }
- + return tensorShape;
- + }
- +
- + /** Gets input tensors from a model. */
- + private static List<Tensor> getInputTensors(Model model) {
- + // TFLite only support one subgraph currently.
- + SubGraph subgraph = model.subgraphs(0);
- + int tensorNum = subgraph.inputsLength();
- + ArrayList<Tensor> inputTensors = new ArrayList<>(tensorNum);
- + for (int i = 0; i < tensorNum; i++) {
- + inputTensors.add(subgraph.tensors(subgraph.inputs(i)));
- + }
- + return Collections.unmodifiableList(inputTensors);
- + }
- +
- + /** Gets output tensors from a model. */
- + private static List<Tensor> getOutputTensors(Model model) {
- + // TFLite only support one subgraph currently.
- + SubGraph subgraph = model.subgraphs(0);
- + int tensorNum = subgraph.outputsLength();
- + ArrayList<Tensor> outputTensors = new ArrayList<>(tensorNum);
- + for (int i = 0; i < tensorNum; i++) {
- + outputTensors.add(subgraph.tensors(subgraph.outputs(i)));
- + }
- + return Collections.unmodifiableList(outputTensors);
- }
- - return Collections.unmodifiableList(outputTensors);
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
- index 751ed500dc2fc..7ee01df094283 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
- @@ -18,136 +18,133 @@ package org.tensorflow.lite.support.metadata;
- import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
- import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
-
- -import java.nio.ByteBuffer;
- -import java.util.ArrayList;
- -import java.util.Collections;
- -import java.util.List;
- import org.checkerframework.checker.nullness.qual.Nullable;
- import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
- import org.tensorflow.lite.support.metadata.schema.SubGraphMetadata;
- import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
-
- +import java.nio.ByteBuffer;
- +import java.util.ArrayList;
- +import java.util.Collections;
- +import java.util.List;
- +
- /** Extracts model metadata information out of TFLite metadata FlatBuffer. */
- final class ModelMetadataInfo {
- - /** The root handler for the model metadata. */
- - private final ModelMetadata modelMetadata;
- -
- - /** Metadata array of input tensors. */
- - private final List</* @Nullable */ TensorMetadata> inputsMetadata;
- -
- - /** Metadata array of output tensors. */
- - private final List</* @Nullable */ TensorMetadata> outputsMetadata;
- -
- - /** The minimum parser version required to fully understand the metadata flatbuffer. */
- - private final String /* @Nullable */ minVersion;
- -
- - /**
- - * Creates a {@link ModelMetadataInfo} with the metadata FlatBuffer, {@code buffer}.
- - *
- - * @param buffer the TFLite metadata FlatBuffer
- - * @throws NullPointerException if {@code buffer} is null
- - * @throws IllegalArgumentException if {@code buffer} does not contain any subgraph metadata, or
- - * it does not contain the expected identifier
- - */
- - ModelMetadataInfo(ByteBuffer buffer) {
- - assertTFLiteMetadata(buffer);
- -
- - modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer);
- - checkArgument(
- - modelMetadata.subgraphMetadataLength() > 0,
- - "The metadata flatbuffer does not contain any subgraph metadata.");
- -
- - inputsMetadata = getInputsMetadata(modelMetadata);
- - outputsMetadata = getOutputsMetadata(modelMetadata);
- - minVersion = modelMetadata.minParserVersion();
- - }
- -
- - /** Gets the count of input tensors with metadata in the metadata FlatBuffer. */
- - int getInputTensorCount() {
- - return inputsMetadata.size();
- - }
- -
- - /**
- - * Gets the metadata for the input tensor specified by {@code inputIndex}.
- - *
- - * @param inputIndex The index of the desired intput tensor.
- - * @throws IllegalArgumentException if the inputIndex specified is invalid.
- - */
- - @Nullable
- - TensorMetadata getInputTensorMetadata(int inputIndex) {
- - checkArgument(
- - inputIndex >= 0 && inputIndex < inputsMetadata.size(),
- - "The inputIndex specified is invalid.");
- - return inputsMetadata.get(inputIndex);
- - }
- -
- - /**
- - * Gets the minimum parser version of the metadata. It can be {@code null} if the version is not
- - * populated.
- - */
- - @Nullable
- - String getMininumParserVersion() {
- - return minVersion;
- - }
- -
- - /** Gets the root handler for the model metadata. */
- - ModelMetadata getModelMetadata() {
- - return modelMetadata;
- - }
- -
- - /** Gets the count of output tensors with metadata in the metadata FlatBuffer. */
- - int getOutputTensorCount() {
- - return outputsMetadata.size();
- - }
- -
- - /**
- - * Gets the metadata for the output tensor specified by {@code outputIndex}.
- - *
- - * @param outputIndex The index of the desired output tensor.
- - * @throws IllegalArgumentException if the outputIndex specified is invalid.
- - */
- - @Nullable
- - TensorMetadata getOutputTensorMetadata(int outputIndex) {
- - checkArgument(
- - outputIndex >= 0 && outputIndex < outputsMetadata.size(),
- - "The outputIndex specified is invalid.");
- - return outputsMetadata.get(outputIndex);
- - }
- -
- - /**
- - * Verifies if the buffer is a valid TFLite metadata flatbuffer.
- - *
- - * @param buffer the TFLite metadata flatbuffer
- - * @throws NullPointerException if {@code buffer} is null.
- - * @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier
- - */
- - private static void assertTFLiteMetadata(ByteBuffer buffer) {
- - checkNotNull(buffer, "Metadata flatbuffer cannot be null.");
- - checkArgument(
- - ModelMetadata.ModelMetadataBufferHasIdentifier(buffer),
- - "The identifier of the metadata is invalid. The buffer may not be a valid TFLite metadata"
- - + " flatbuffer.");
- - }
- -
- - /** Gets metadata for all input tensors. */
- - private static List<TensorMetadata> getInputsMetadata(ModelMetadata modelMetadata) {
- - SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
- - int tensorNum = subgraphMetadata.inputTensorMetadataLength();
- - ArrayList<TensorMetadata> inputsMetadata = new ArrayList<>(tensorNum);
- - for (int i = 0; i < tensorNum; i++) {
- - inputsMetadata.add(subgraphMetadata.inputTensorMetadata(i));
- + /** The root handler for the model metadata. */
- + private final ModelMetadata modelMetadata;
- +
- + /** Metadata array of input tensors. */
- + private final List</* @Nullable */ TensorMetadata> inputsMetadata;
- +
- + /** Metadata array of output tensors. */
- + private final List</* @Nullable */ TensorMetadata> outputsMetadata;
- +
- + /** The minimum parser version required to fully understand the metadata flatbuffer. */
- + private final String /* @Nullable */ minVersion;
- +
- + /**
- + * Creates a {@link ModelMetadataInfo} with the metadata FlatBuffer, {@code buffer}.
- + *
- + * @param buffer the TFLite metadata FlatBuffer
- + * @throws NullPointerException if {@code buffer} is null
- + * @throws IllegalArgumentException if {@code buffer} does not contain any subgraph metadata, or
- + * it does not contain the expected identifier
- + */
- + ModelMetadataInfo(ByteBuffer buffer) {
- + assertTFLiteMetadata(buffer);
- +
- + modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer);
- + checkArgument(modelMetadata.subgraphMetadataLength() > 0,
- + "The metadata flatbuffer does not contain any subgraph metadata.");
- +
- + inputsMetadata = getInputsMetadata(modelMetadata);
- + outputsMetadata = getOutputsMetadata(modelMetadata);
- + minVersion = modelMetadata.minParserVersion();
- + }
- +
- + /** Gets the count of input tensors with metadata in the metadata FlatBuffer. */
- + int getInputTensorCount() {
- + return inputsMetadata.size();
- + }
- +
- + /**
- + * Gets the metadata for the input tensor specified by {@code inputIndex}.
- + *
- + * @param inputIndex The index of the desired intput tensor.
- + * @throws IllegalArgumentException if the inputIndex specified is invalid.
- + */
- + @Nullable
- + TensorMetadata getInputTensorMetadata(int inputIndex) {
- + checkArgument(inputIndex >= 0 && inputIndex < inputsMetadata.size(),
- + "The inputIndex specified is invalid.");
- + return inputsMetadata.get(inputIndex);
- }
- - return Collections.unmodifiableList(inputsMetadata);
- - }
- -
- - /** Gets metadata for all output tensors. */
- - private static List<TensorMetadata> getOutputsMetadata(ModelMetadata modelMetadata) {
- - SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
- - int tensorNum = subgraphMetadata.outputTensorMetadataLength();
- - ArrayList<TensorMetadata> outputsMetadata = new ArrayList<>(tensorNum);
- - for (int i = 0; i < tensorNum; i++) {
- - outputsMetadata.add(subgraphMetadata.outputTensorMetadata(i));
- +
- + /**
- + * Gets the minimum parser version of the metadata. It can be {@code null} if the version is not
- + * populated.
- + */
- + @Nullable
- + String getMininumParserVersion() {
- + return minVersion;
- + }
- +
- + /** Gets the root handler for the model metadata. */
- + ModelMetadata getModelMetadata() {
- + return modelMetadata;
- + }
- +
- + /** Gets the count of output tensors with metadata in the metadata FlatBuffer. */
- + int getOutputTensorCount() {
- + return outputsMetadata.size();
- + }
- +
- + /**
- + * Gets the metadata for the output tensor specified by {@code outputIndex}.
- + *
- + * @param outputIndex The index of the desired output tensor.
- + * @throws IllegalArgumentException if the outputIndex specified is invalid.
- + */
- + @Nullable
- + TensorMetadata getOutputTensorMetadata(int outputIndex) {
- + checkArgument(outputIndex >= 0 && outputIndex < outputsMetadata.size(),
- + "The outputIndex specified is invalid.");
- + return outputsMetadata.get(outputIndex);
- + }
- +
- + /**
- + * Verifies if the buffer is a valid TFLite metadata flatbuffer.
- + *
- + * @param buffer the TFLite metadata flatbuffer
- + * @throws NullPointerException if {@code buffer} is null.
- + * @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier
- + */
- + private static void assertTFLiteMetadata(ByteBuffer buffer) {
- + checkNotNull(buffer, "Metadata flatbuffer cannot be null.");
- + checkArgument(ModelMetadata.ModelMetadataBufferHasIdentifier(buffer),
- + "The identifier of the metadata is invalid. The buffer may not be a valid TFLite metadata"
- + + " flatbuffer.");
- + }
- +
- + /** Gets metadata for all input tensors. */
- + private static List<TensorMetadata> getInputsMetadata(ModelMetadata modelMetadata) {
- + SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
- + int tensorNum = subgraphMetadata.inputTensorMetadataLength();
- + ArrayList<TensorMetadata> inputsMetadata = new ArrayList<>(tensorNum);
- + for (int i = 0; i < tensorNum; i++) {
- + inputsMetadata.add(subgraphMetadata.inputTensorMetadata(i));
- + }
- + return Collections.unmodifiableList(inputsMetadata);
- + }
- +
- + /** Gets metadata for all output tensors. */
- + private static List<TensorMetadata> getOutputsMetadata(ModelMetadata modelMetadata) {
- + SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
- + int tensorNum = subgraphMetadata.outputTensorMetadataLength();
- + ArrayList<TensorMetadata> outputsMetadata = new ArrayList<>(tensorNum);
- + for (int i = 0; i < tensorNum; i++) {
- + outputsMetadata.add(subgraphMetadata.outputTensorMetadata(i));
- + }
- + return Collections.unmodifiableList(outputsMetadata);
- }
- - return Collections.unmodifiableList(outputsMetadata);
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/Preconditions.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/Preconditions.java
- index c2f20fbaacd76..ca3eed3490644 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/Preconditions.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/Preconditions.java
- @@ -19,166 +19,170 @@ import org.checkerframework.checker.nullness.qual.Nullable;
-
- /** Static error checking util methods. */
- final class Preconditions {
- - /**
- - * Ensures that an object reference passed as a parameter to the calling method is not null.
- - *
- - * @param reference an object reference
- - * @return the non-null reference that was validated
- - * @throws NullPointerException if {@code reference} is null
- - */
- - public static <T extends Object> T checkNotNull(T reference) {
- - if (reference == null) {
- - throw new NullPointerException("The object reference is null.");
- + /**
- + * Ensures that an object reference passed as a parameter to the calling method is not null.
- + *
- + * @param reference an object reference
- + * @return the non-null reference that was validated
- + * @throws NullPointerException if {@code reference} is null
- + */
- + public static <T extends Object> T checkNotNull(T reference) {
- + if (reference == null) {
- + throw new NullPointerException("The object reference is null.");
- + }
- + return reference;
- }
- - return reference;
- - }
- -
- - /**
- - * Ensures that an object reference passed as a parameter to the calling method is not null.
- - *
- - * @param reference an object reference
- - * @param errorMessage the exception message to use if the check fails; will be converted to a
- - * string using {@link String#valueOf(Object)}
- - * @return the non-null reference that was validated
- - * @throws NullPointerException if {@code reference} is null
- - */
- - public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
- - if (reference == null) {
- - throw new NullPointerException(String.valueOf(errorMessage));
- +
- + /**
- + * Ensures that an object reference passed as a parameter to the calling method is not null.
- + *
- + * @param reference an object reference
- + * @param errorMessage the exception message to use if the check fails; will be converted to a
- + * string using {@link String#valueOf(Object)}
- + * @return the non-null reference that was validated
- + * @throws NullPointerException if {@code reference} is null
- + */
- + public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
- + if (reference == null) {
- + throw new NullPointerException(String.valueOf(errorMessage));
- + }
- + return reference;
- + }
- +
- + /**
- + * Ensures that the given String is not empty and not null.
- + *
- + * @param string the String to test
- + * @return the non-null non-empty String that was validated
- + * @throws IllegalArgumentException if {@code string} is null or empty
- + */
- + public static String checkNotEmpty(String string) {
- + if (string == null || string.length() == 0) {
- + throw new IllegalArgumentException("Given String is empty or null.");
- + }
- + return string;
- }
- - return reference;
- - }
- -
- - /**
- - * Ensures that the given String is not empty and not null.
- - *
- - * @param string the String to test
- - * @return the non-null non-empty String that was validated
- - * @throws IllegalArgumentException if {@code string} is null or empty
- - */
- - public static String checkNotEmpty(String string) {
- - if (string == null || string.length() == 0) {
- - throw new IllegalArgumentException("Given String is empty or null.");
- +
- + /**
- + * Ensures that the given String is not empty and not null.
- + *
- + * @param string the String to test
- + * @param errorMessage the exception message to use if the check fails; will be converted to a
- + * string using {@link String#valueOf(Object)}
- + * @return the non-null non-empty String that was validated
- + * @throws IllegalArgumentException if {@code string} is null or empty
- + */
- + public static String checkNotEmpty(String string, Object errorMessage) {
- + if (string == null || string.length() == 0) {
- + throw new IllegalArgumentException(String.valueOf(errorMessage));
- + }
- + return string;
- }
- - return string;
- - }
- -
- - /**
- - * Ensures that the given String is not empty and not null.
- - *
- - * @param string the String to test
- - * @param errorMessage the exception message to use if the check fails; will be converted to a
- - * string using {@link String#valueOf(Object)}
- - * @return the non-null non-empty String that was validated
- - * @throws IllegalArgumentException if {@code string} is null or empty
- - */
- - public static String checkNotEmpty(String string, Object errorMessage) {
- - if (string == null || string.length() == 0) {
- - throw new IllegalArgumentException(String.valueOf(errorMessage));
- +
- + /**
- + * Ensures the truth of an expression involving one or more parameters to the calling method.
- + *
- + * @param expression a boolean expression.
- + * @throws IllegalArgumentException if {@code expression} is false.
- + */
- + public static void checkArgument(boolean expression) {
- + if (!expression) {
- + throw new IllegalArgumentException();
- + }
- }
- - return string;
- - }
- -
- - /**
- - * Ensures the truth of an expression involving one or more parameters to the calling method.
- - *
- - * @param expression a boolean expression.
- - * @throws IllegalArgumentException if {@code expression} is false.
- - */
- - public static void checkArgument(boolean expression) {
- - if (!expression) {
- - throw new IllegalArgumentException();
- +
- + /**
- + * Ensures the truth of an expression involving one or more parameters to the calling method.
- + *
- + * @param expression a boolean expression.
- + * @param errorMessage the exception message to use if the check fails; will be converted to a
- + * string using {@link String#valueOf(Object)}.
- + * @throws IllegalArgumentException if {@code expression} is false.
- + */
- + public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
- + if (!expression) {
- + throw new IllegalArgumentException(String.valueOf(errorMessage));
- + }
- }
- - }
- -
- - /**
- - * Ensures the truth of an expression involving one or more parameters to the calling method.
- - *
- - * @param expression a boolean expression.
- - * @param errorMessage the exception message to use if the check fails; will be converted to a
- - * string using {@link String#valueOf(Object)}.
- - * @throws IllegalArgumentException if {@code expression} is false.
- - */
- - public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
- - if (!expression) {
- - throw new IllegalArgumentException(String.valueOf(errorMessage));
- +
- + /**
- + * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of
- + * size
- + * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
- + *
- + * @param index a user-supplied index identifying an element of an array, list or string
- + * @param size the size of that array, list or string
- + * @return the value of {@code index}
- + * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code
- + * size}
- + * @throws IllegalArgumentException if {@code size} is negative
- + */
- + public static int checkElementIndex(int index, int size) {
- + return checkElementIndex(index, size, "index");
- }
- - }
- -
- - /**
- - * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
- - * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
- - *
- - * @param index a user-supplied index identifying an element of an array, list or string
- - * @param size the size of that array, list or string
- - * @return the value of {@code index}
- - * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
- - * @throws IllegalArgumentException if {@code size} is negative
- - */
- - public static int checkElementIndex(int index, int size) {
- - return checkElementIndex(index, size, "index");
- - }
- -
- - /**
- - * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
- - * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
- - *
- - * @param index a user-supplied index identifying an element of an array, list or string
- - * @param size the size of that array, list or string
- - * @param desc the text to use to describe this index in an error message
- - * @return the value of {@code index}
- - * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
- - * @throws IllegalArgumentException if {@code size} is negative
- - */
- - public static int checkElementIndex(int index, int size, @Nullable String desc) {
- - // Carefully optimized for execution by hotspot (explanatory comment above)
- - if (index < 0 || index >= size) {
- - throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
- +
- + /**
- + * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of
- + * size
- + * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
- + *
- + * @param index a user-supplied index identifying an element of an array, list or string
- + * @param size the size of that array, list or string
- + * @param desc the text to use to describe this index in an error message
- + * @return the value of {@code index}
- + * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code
- + * size}
- + * @throws IllegalArgumentException if {@code size} is negative
- + */
- + public static int checkElementIndex(int index, int size, @Nullable String desc) {
- + // Carefully optimized for execution by hotspot (explanatory comment above)
- + if (index < 0 || index >= size) {
- + throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
- + }
- + return index;
- }
- - return index;
- - }
- -
- - /**
- - * Ensures the truth of an expression involving the state of the calling instance, but not
- - * involving any parameters to the calling method.
- - *
- - * @param expression a boolean expression
- - * @throws IllegalStateException if {@code expression} is false
- - * @see Verify#verify Verify.verify()
- - */
- - public static void checkState(boolean expression) {
- - if (!expression) {
- - throw new IllegalStateException();
- +
- + /**
- + * Ensures the truth of an expression involving the state of the calling instance, but not
- + * involving any parameters to the calling method.
- + *
- + * @param expression a boolean expression
- + * @throws IllegalStateException if {@code expression} is false
- + * @see Verify#verify Verify.verify()
- + */
- + public static void checkState(boolean expression) {
- + if (!expression) {
- + throw new IllegalStateException();
- + }
- }
- - }
- -
- - /**
- - * Ensures the truth of an expression involving the state of the calling instance, but not
- - * involving any parameters to the calling method.
- - *
- - * @param expression a boolean expression
- - * @param errorMessage the exception message to use if the check fails; will be converted to a
- - * string using {@link String#valueOf(Object)}
- - * @throws IllegalStateException if {@code expression} is false
- - * @see Verify#verify Verify.verify()
- - */
- - public static void checkState(boolean expression, @Nullable Object errorMessage) {
- - if (!expression) {
- - throw new IllegalStateException(String.valueOf(errorMessage));
- +
- + /**
- + * Ensures the truth of an expression involving the state of the calling instance, but not
- + * involving any parameters to the calling method.
- + *
- + * @param expression a boolean expression
- + * @param errorMessage the exception message to use if the check fails; will be converted to a
- + * string using {@link String#valueOf(Object)}
- + * @throws IllegalStateException if {@code expression} is false
- + * @see Verify#verify Verify.verify()
- + */
- + public static void checkState(boolean expression, @Nullable Object errorMessage) {
- + if (!expression) {
- + throw new IllegalStateException(String.valueOf(errorMessage));
- + }
- }
- - }
- -
- - private static String badElementIndex(int index, int size, @Nullable String desc) {
- - if (index < 0) {
- - return String.format("%s (%s) must not be negative", desc, index);
- - } else if (size < 0) {
- - throw new IllegalArgumentException("negative size: " + size);
- - } else { // index >= size
- - return String.format("%s (%s) must be less than size (%s)", desc, index, size);
- +
- + private static String badElementIndex(int index, int size, @Nullable String desc) {
- + if (index < 0) {
- + return String.format("%s (%s) must not be negative", desc, index);
- + } else if (size < 0) {
- + throw new IllegalArgumentException("negative size: " + size);
- + } else { // index >= size
- + return String.format("%s (%s) must be less than size (%s)", desc, index, size);
- + }
- }
- - }
-
- - private Preconditions() {
- - throw new AssertionError("Preconditions is Uninstantiable.");
- - }
- + private Preconditions() {
- + throw new AssertionError("Preconditions is Uninstantiable.");
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java
- index c655786755baa..1408a3a73d86b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java
- @@ -29,79 +29,79 @@ import java.nio.channels.Channel;
- * the MetadtaExtractor library consistent with the common used Java libraries.
- */
- interface SeekableByteChannelCompat extends Channel {
- - /**
- - * Reads a sequence of bytes from this channel into the given buffer.
- - *
- - * @param dst The buffer into which bytes are to be transferred
- - * @return The number of bytes read, possibly zero, or <tt>-1</tt> if the channel has reached
- - * end-of-stream
- - * @throws NonReadableChannelException If this channel was not opened for reading
- - * @throws ClosedChannelException If this channel is closed
- - * @throws AsynchronousCloseException If another thread closes this channel while the read
- - * operation is in progress
- - * @throws ClosedByInterruptException If another thread interrupts the current thread while the
- - * read operation is in progress, thereby closing the channel and setting the current thread's
- - * interrupt status
- - * @throws IOException If some other I/O error occurs
- - */
- - int read(ByteBuffer dst) throws IOException;
- + /**
- + * Reads a sequence of bytes from this channel into the given buffer.
- + *
- + * @param dst The buffer into which bytes are to be transferred
- + * @return The number of bytes read, possibly zero, or <tt>-1</tt> if the channel has reached
- + * end-of-stream
- + * @throws NonReadableChannelException If this channel was not opened for reading
- + * @throws ClosedChannelException If this channel is closed
- + * @throws AsynchronousCloseException If another thread closes this channel while the read
- + * operation is in progress
- + * @throws ClosedByInterruptException If another thread interrupts the current thread while the
- + * read operation is in progress, thereby closing the channel and setting the current
- + * thread's interrupt status
- + * @throws IOException If some other I/O error occurs
- + */
- + int read(ByteBuffer dst) throws IOException;
-
- - /**
- - * Writes a sequence of bytes to this channel from the given buffer.
- - *
- - * @param src The buffer from which bytes are to be retrieved
- - * @return The number of bytes written, possibly zero
- - * @throws NonWritableChannelException If this channel was not opened for writing
- - * @throws ClosedChannelException If this channel is closed
- - * @throws AsynchronousCloseException If another thread closes this channel while the write
- - * operation is in progress
- - * @throws ClosedByInterruptException If another thread interrupts the current thread while the
- - * write operation is in progress, thereby closing the channel and setting the current
- - * thread's interrupt status
- - * @throws IOException If some other I/O error occurs
- - */
- - int write(ByteBuffer src) throws IOException;
- + /**
- + * Writes a sequence of bytes to this channel from the given buffer.
- + *
- + * @param src The buffer from which bytes are to be retrieved
- + * @return The number of bytes written, possibly zero
- + * @throws NonWritableChannelException If this channel was not opened for writing
- + * @throws ClosedChannelException If this channel is closed
- + * @throws AsynchronousCloseException If another thread closes this channel while the write
- + * operation is in progress
- + * @throws ClosedByInterruptException If another thread interrupts the current thread while the
- + * write operation is in progress, thereby closing the channel and setting the current
- + * thread's interrupt status
- + * @throws IOException If some other I/O error occurs
- + */
- + int write(ByteBuffer src) throws IOException;
-
- - /**
- - * Returns this channel's position.
- - *
- - * @return This channel's position, a non-negative integer counting the number of bytes from the
- - * beginning of the entity to the current position
- - * @throws ClosedChannelException If this channel is closed
- - * @throws IOException If some other I/O error occurs
- - */
- - long position() throws IOException;
- + /**
- + * Returns this channel's position.
- + *
- + * @return This channel's position, a non-negative integer counting the number of bytes from the
- + * beginning of the entity to the current position
- + * @throws ClosedChannelException If this channel is closed
- + * @throws IOException If some other I/O error occurs
- + */
- + long position() throws IOException;
-
- - /**
- - * Sets this channel's position.
- - *
- - * @param newPosition The new position, a non-negative integer counting the number of bytes from
- - * the beginning of the entity
- - * @return This channel
- - * @throws ClosedChannelException If this channel is closed
- - * @throws IllegalArgumentException If the new position is negative
- - * @throws IOException If some other I/O error occurs
- - */
- - SeekableByteChannelCompat position(long newPosition) throws IOException;
- + /**
- + * Sets this channel's position.
- + *
- + * @param newPosition The new position, a non-negative integer counting the number of bytes from
- + * the beginning of the entity
- + * @return This channel
- + * @throws ClosedChannelException If this channel is closed
- + * @throws IllegalArgumentException If the new position is negative
- + * @throws IOException If some other I/O error occurs
- + */
- + SeekableByteChannelCompat position(long newPosition) throws IOException;
-
- - /**
- - * Returns the current size of entity to which this channel is connected.
- - *
- - * @return The current size, measured in bytes
- - * @throws ClosedChannelException If this channel is closed
- - * @throws IOException If some other I/O error occurs
- - */
- - long size() throws IOException;
- + /**
- + * Returns the current size of entity to which this channel is connected.
- + *
- + * @return The current size, measured in bytes
- + * @throws ClosedChannelException If this channel is closed
- + * @throws IOException If some other I/O error occurs
- + */
- + long size() throws IOException;
-
- - /**
- - * Truncates the entity, to which this channel is connected, to the given size.
- - *
- - * @param size The new size, a non-negative byte count
- - * @return This channel
- - * @throws NonWritableChannelException If this channel was not opened for writing
- - * @throws ClosedChannelException If this channel is closed
- - * @throws IllegalArgumentException If the new size is negative
- - * @throws IOException If some other I/O error occurs
- - */
- - SeekableByteChannelCompat truncate(long size) throws IOException;
- + /**
- + * Truncates the entity, to which this channel is connected, to the given size.
- + *
- + * @param size The new size, a non-negative byte count
- + * @return This channel
- + * @throws NonWritableChannelException If this channel was not opened for writing
- + * @throws ClosedChannelException If this channel is closed
- + * @throws IllegalArgumentException If the new size is negative
- + * @throws IOException If some other I/O error occurs
- + */
- + SeekableByteChannelCompat truncate(long size) throws IOException;
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java
- index 6b43e724fd814..c8a3fb806d920 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java
- @@ -45,393 +45,389 @@ import java.util.zip.ZipException;
- * size limit for Zip64, which is 4GB.
- */
- final class ZipFile implements Closeable {
- - /** Maps String to list of ZipEntrys, name -> actual entries. */
- - private final Map<String, List<ZipEntry>> nameMap;
- -
- - /** The actual data source. */
- - private final ByteBufferChannel archive;
- -
- - /**
- - * Opens the given {@link ByteBufferChannel} for reading, assuming "UTF8" for file names. {@link
- - * ZipFile} does not synchronized over the buffer that is passed into it.
- - *
- - * @param channel the archive
- - * @throws IOException if an error occurs while creating this {@link ZipFile}
- - * @throws ZipException if the channel is not a zip archive
- - * @throws NullPointerException if the archive is null
- - */
- - public static ZipFile createFrom(ByteBufferChannel channel) throws IOException {
- - checkNotNull(channel);
- - ZipParser zipParser = new ZipParser(channel);
- - Map<String, List<ZipEntry>> nameMap = zipParser.parseEntries();
- - return new ZipFile(channel, nameMap);
- - }
- -
- - @Override
- - public void close() {
- - archive.close();
- - }
- -
- - /**
- - * Exposes the raw stream of the archive entry.
- - *
- - * <p>Since the associated files will not be compressed when being packed to the zip file, the raw
- - * stream represents the non-compressed files.
- - *
- - * <p><b>WARNING:</b> The returned {@link InputStream}, is <b>not</b> thread-safe. If multiple
- - * threads concurrently reading from the returned {@link InputStream}, it must be synchronized
- - * externally.
- - *
- - * @param name name of the entry to get the stream for
- - * @return the raw input stream containing data
- - * @throws IllegalArgumentException if the specified file does not exist in the zip file
- - */
- - public InputStream getRawInputStream(String name) {
- - checkArgument(
- - nameMap.containsKey(name),
- - String.format("The file, %s, does not exist in the zip file.", name));
- -
- - List<ZipEntry> entriesWithTheSameName = nameMap.get(name);
- - ZipEntry entry = entriesWithTheSameName.get(0);
- - long start = entry.getDataOffset();
- - long remaining = entry.getSize();
- - return new BoundedInputStream(archive, start, remaining);
- - }
- -
- - /**
- - * Exposes the file names of the included files.
- - *
- - * @return the file names of the included files
- - */
- - public Set<String> getFileNames() {
- - return nameMap.keySet();
- - }
- -
- - private ZipFile(ByteBufferChannel channel, Map<String, List<ZipEntry>> nameMap) {
- - archive = channel;
- - this.nameMap = nameMap;
- - }
- -
- - /* Parses a Zip archive and gets the information for each {@link ZipEntry}. */
- - private static class ZipParser {
- - private final ByteBufferChannel archive;
- -
- - // Cached buffers that will only be used locally in the class to reduce garbage collection.
- - private final ByteBuffer longBuffer =
- - ByteBuffer.allocate(ZipConstants.LONG_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
- - private final ByteBuffer intBuffer =
- - ByteBuffer.allocate(ZipConstants.INT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
- - private final ByteBuffer shortBuffer =
- - ByteBuffer.allocate(ZipConstants.SHORT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
- + /** Maps String to list of ZipEntrys, name -> actual entries. */
- + private final Map<String, List<ZipEntry>> nameMap;
-
- - private ZipParser(ByteBufferChannel archive) {
- - this.archive = archive;
- - }
- -
- - /**
- - * Parses the underlying {@code archive} and returns the information as a list of {@link
- - * ZipEntry}.
- - */
- - private Map<String, List<ZipEntry>> parseEntries() throws IOException {
- - List<ZipEntry> entries = parseCentralDirectory();
- - return parseLocalFileHeaderData(entries);
- - }
- -
- - /**
- - * Checks if the current position contains a central file header signature, {@link
- - * ZipConstants#CENSIG}.
- - */
- - private boolean foundCentralFileheaderSignature() {
- - long signature = (long) getInt();
- - return signature == ZipConstants.CENSIG;
- - }
- -
- - /**
- - * Gets the value as a Java int from two bytes starting at the current position of the archive.
- - */
- - private int getShort() {
- - shortBuffer.rewind();
- - archive.read(shortBuffer);
- - shortBuffer.flip();
- - return (int) shortBuffer.getShort();
- - }
- + /** The actual data source. */
- + private final ByteBufferChannel archive;
-
- /**
- - * Gets the value as a Java long from four bytes starting at the current position of the
- - * archive.
- + * Opens the given {@link ByteBufferChannel} for reading, assuming "UTF8" for file names. {@link
- + * ZipFile} does not synchronized over the buffer that is passed into it.
- + *
- + * @param channel the archive
- + * @throws IOException if an error occurs while creating this {@link ZipFile}
- + * @throws ZipException if the channel is not a zip archive
- + * @throws NullPointerException if the archive is null
- */
- - private int getInt() {
- - intBuffer.rewind();
- - archive.read(intBuffer);
- - intBuffer.flip();
- - return intBuffer.getInt();
- + public static ZipFile createFrom(ByteBufferChannel channel) throws IOException {
- + checkNotNull(channel);
- + ZipParser zipParser = new ZipParser(channel);
- + Map<String, List<ZipEntry>> nameMap = zipParser.parseEntries();
- + return new ZipFile(channel, nameMap);
- }
-
- - /**
- - * Gets the value as a Java long from four bytes starting at the current position of the
- - * archive.
- - */
- - private long getLong() {
- - longBuffer.rewind();
- - archive.read(longBuffer);
- - longBuffer.flip();
- - return longBuffer.getLong();
- + @Override
- + public void close() {
- + archive.close();
- }
-
- /**
- - * Positions the archive at the start of the central directory.
- + * Exposes the raw stream of the archive entry.
- + *
- + * <p>Since the associated files will not be compressed when being packed to the zip file, the
- + * raw stream represents the non-compressed files.
- *
- - * <p>First, it searches for the signature of the "end of central directory record", {@link
- - * ZipConstants#ENDSIG}. Position the stream at the start of the "end of central directory
- - * record". The zip file are created without archive comments, thus {@link ZipConstants#ENDSIG}
- - * should appear exactly at {@link ZipConstants#ENDHDR} from the end of the zip file.
- + * <p><b>WARNING:</b> The returned {@link InputStream}, is <b>not</b> thread-safe. If multiple
- + * threads concurrently reading from the returned {@link InputStream}, it must be synchronized
- + * externally.
- *
- - * <p>Then, parse the "end of central dir record" and position the archive at the start of the
- - * central directory.
- + * @param name name of the entry to get the stream for
- + * @return the raw input stream containing data
- + * @throws IllegalArgumentException if the specified file does not exist in the zip file
- */
- - private void locateCentralDirectory() throws IOException {
- - if (archive.size() < ZipConstants.ENDHDR) {
- - throw new ZipException("The archive is not a ZIP archive.");
- - }
- -
- - // Positions the archive at the start of the "end of central directory record".
- - long offsetRecord = archive.size() - ZipConstants.ENDHDR;
- - archive.position(offsetRecord);
- -
- - // Checks for the signature, {@link ZipConstants#ENDSIG}.
- - long endSig = getLong();
- - if (endSig != ZipConstants.ENDSIG) {
- - throw new ZipException("The archive is not a ZIP archive.");
- - }
- -
- - // Positions the archive at the “offset of central directory”.
- - skipBytes(ZipConstants.ENDOFF - ZipConstants.ENDSUB);
- - // Gets the offset to central directory
- - long offsetDirectory = getInt();
- - // Goes to the central directory.
- - archive.position(offsetDirectory);
- + public InputStream getRawInputStream(String name) {
- + checkArgument(nameMap.containsKey(name),
- + String.format("The file, %s, does not exist in the zip file.", name));
- +
- + List<ZipEntry> entriesWithTheSameName = nameMap.get(name);
- + ZipEntry entry = entriesWithTheSameName.get(0);
- + long start = entry.getDataOffset();
- + long remaining = entry.getSize();
- + return new BoundedInputStream(archive, start, remaining);
- }
-
- /**
- - * Reads the central directory of the given archive and populates the internal tables with
- - * {@link ZipEntry} instances.
- + * Exposes the file names of the included files.
- + *
- + * @return the file names of the included files
- */
- - private List<ZipEntry> parseCentralDirectory() throws IOException {
- - /** List of entries in the order they appear inside the central directory. */
- - List<ZipEntry> entries = new ArrayList<>();
- - locateCentralDirectory();
- -
- - while (foundCentralFileheaderSignature()) {
- - ZipEntry entry = parseCentralDirectoryEntry();
- - entries.add(entry);
- - }
- -
- - return entries;
- + public Set<String> getFileNames() {
- + return nameMap.keySet();
- }
-
- - /**
- - * Reads an individual entry of the central directory, creats an ZipEntry from it and adds it to
- - * the global maps.
- - */
- - private ZipEntry parseCentralDirectoryEntry() throws IOException {
- - // Positions the archive at the "compressed size" and read the value.
- - skipBytes(ZipConstants.CENSIZ - ZipConstants.CENVEM);
- - long compressSize = getInt();
- -
- - // Positions the archive at the "filename length" and read the value.
- - skipBytes(ZipConstants.CENNAM - ZipConstants.CENLEN);
- - int fileNameLen = getShort();
- -
- - // Reads the extra field length and the comment length.
- - int extraLen = getShort();
- - int commentLen = getShort();
- -
- - // Positions the archive at the "local file header offset" and read the value.
- - skipBytes(ZipConstants.CENOFF - ZipConstants.CENDSK);
- - long localHeaderOffset = getInt();
- -
- - // Reads the file name.
- - byte[] fileNameBuf = new byte[fileNameLen];
- - archive.read(ByteBuffer.wrap(fileNameBuf));
- - String fileName = new String(fileNameBuf, Charset.forName("UTF-8"));
- + private ZipFile(ByteBufferChannel channel, Map<String, List<ZipEntry>> nameMap) {
- + archive = channel;
- + this.nameMap = nameMap;
- + }
-
- - // Skips the extra field and the comment.
- - skipBytes(extraLen + commentLen);
- + /* Parses a Zip archive and gets the information for each {@link ZipEntry}. */
- + private static class ZipParser {
- + private final ByteBufferChannel archive;
-
- - ZipEntry entry = new ZipEntry();
- - entry.setSize(compressSize);
- - entry.setLocalHeaderOffset(localHeaderOffset);
- - entry.setName(fileName);
- + // Cached buffers that will only be used locally in the class to reduce garbage collection.
- + private final ByteBuffer longBuffer =
- + ByteBuffer.allocate(ZipConstants.LONG_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
- + private final ByteBuffer intBuffer =
- + ByteBuffer.allocate(ZipConstants.INT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
- + private final ByteBuffer shortBuffer =
- + ByteBuffer.allocate(ZipConstants.SHORT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
-
- - return entry;
- - }
- + private ZipParser(ByteBufferChannel archive) {
- + this.archive = archive;
- + }
-
- - /** Walks through all recorded entries and records the offsets for the entry data. */
- - private Map<String, List<ZipEntry>> parseLocalFileHeaderData(List<ZipEntry> entries) {
- - /** Maps String to list of ZipEntrys, name -> actual entries. */
- - Map<String, List<ZipEntry>> nameMap = new LinkedHashMap<>();
- -
- - for (ZipEntry entry : entries) {
- - long offset = entry.getLocalHeaderOffset();
- - archive.position(offset + ZipConstants.LOCNAM);
- -
- - // Gets the data offset of this entry.
- - int fileNameLen = getShort();
- - int extraFieldLen = getShort();
- - long dataOffset =
- - offset
- - + ZipConstants.LOCEXT
- - + ZipConstants.SHORT_BYTE_SIZE
- - + fileNameLen
- - + extraFieldLen;
- - entry.setDataOffset(dataOffset);
- -
- - // Puts the entry into the nameMap.
- - String name = entry.getName();
- - List<ZipEntry> entriesWithTheSameName;
- - if (nameMap.containsKey(name)) {
- - entriesWithTheSameName = nameMap.get(name);
- - } else {
- - entriesWithTheSameName = new ArrayList<>();
- - nameMap.put(name, entriesWithTheSameName);
- + /**
- + * Parses the underlying {@code archive} and returns the information as a list of {@link
- + * ZipEntry}.
- + */
- + private Map<String, List<ZipEntry>> parseEntries() throws IOException {
- + List<ZipEntry> entries = parseCentralDirectory();
- + return parseLocalFileHeaderData(entries);
- }
- - entriesWithTheSameName.add(entry);
- - }
-
- - return nameMap;
- - }
- + /**
- + * Checks if the current position contains a central file header signature, {@link
- + * ZipConstants#CENSIG}.
- + */
- + private boolean foundCentralFileheaderSignature() {
- + long signature = (long) getInt();
- + return signature == ZipConstants.CENSIG;
- + }
-
- - /** Skips the given number of bytes or throws an EOFException if skipping failed. */
- - private void skipBytes(int count) throws IOException {
- - long currentPosition = archive.position();
- - long newPosition = currentPosition + count;
- - if (newPosition > archive.size()) {
- - throw new EOFException();
- - }
- - archive.position(newPosition);
- - }
- - }
- + /**
- + * Gets the value as a Java int from two bytes starting at the current position of the
- + * archive.
- + */
- + private int getShort() {
- + shortBuffer.rewind();
- + archive.read(shortBuffer);
- + shortBuffer.flip();
- + return (int) shortBuffer.getShort();
- + }
-
- - /** Stores the data offset and the size of an entry in the archive. */
- - private static class ZipEntry {
- + /**
- + * Gets the value as a Java long from four bytes starting at the current position of the
- + * archive.
- + */
- + private int getInt() {
- + intBuffer.rewind();
- + archive.read(intBuffer);
- + intBuffer.flip();
- + return intBuffer.getInt();
- + }
-
- - private String name;
- - private long dataOffset = -1;
- - private long size = -1;
- - private long localHeaderOffset = -1;
- + /**
- + * Gets the value as a Java long from four bytes starting at the current position of the
- + * archive.
- + */
- + private long getLong() {
- + longBuffer.rewind();
- + archive.read(longBuffer);
- + longBuffer.flip();
- + return longBuffer.getLong();
- + }
-
- - public long getSize() {
- - return size;
- - }
- + /**
- + * Positions the archive at the start of the central directory.
- + *
- + * <p>First, it searches for the signature of the "end of central directory record", {@link
- + * ZipConstants#ENDSIG}. Position the stream at the start of the "end of central directory
- + * record". The zip file are created without archive comments, thus {@link
- + * ZipConstants#ENDSIG} should appear exactly at {@link ZipConstants#ENDHDR} from the end of
- + * the zip file.
- + *
- + * <p>Then, parse the "end of central dir record" and position the archive at the start of
- + * the central directory.
- + */
- + private void locateCentralDirectory() throws IOException {
- + if (archive.size() < ZipConstants.ENDHDR) {
- + throw new ZipException("The archive is not a ZIP archive.");
- + }
- +
- + // Positions the archive at the start of the "end of central directory record".
- + long offsetRecord = archive.size() - ZipConstants.ENDHDR;
- + archive.position(offsetRecord);
- +
- + // Checks for the signature, {@link ZipConstants#ENDSIG}.
- + long endSig = getLong();
- + if (endSig != ZipConstants.ENDSIG) {
- + throw new ZipException("The archive is not a ZIP archive.");
- + }
- +
- + // Positions the archive at the “offset of central directory”.
- + skipBytes(ZipConstants.ENDOFF - ZipConstants.ENDSUB);
- + // Gets the offset to central directory
- + long offsetDirectory = getInt();
- + // Goes to the central directory.
- + archive.position(offsetDirectory);
- + }
-
- - public long getDataOffset() {
- - return dataOffset;
- - }
- + /**
- + * Reads the central directory of the given archive and populates the internal tables with
- + * {@link ZipEntry} instances.
- + */
- + private List<ZipEntry> parseCentralDirectory() throws IOException {
- + /** List of entries in the order they appear inside the central directory. */
- + List<ZipEntry> entries = new ArrayList<>();
- + locateCentralDirectory();
- +
- + while (foundCentralFileheaderSignature()) {
- + ZipEntry entry = parseCentralDirectoryEntry();
- + entries.add(entry);
- + }
- +
- + return entries;
- + }
-
- - public String getName() {
- - return name;
- - }
- + /**
- + * Reads an individual entry of the central directory, creats an ZipEntry from it and adds
- + * it to the global maps.
- + */
- + private ZipEntry parseCentralDirectoryEntry() throws IOException {
- + // Positions the archive at the "compressed size" and read the value.
- + skipBytes(ZipConstants.CENSIZ - ZipConstants.CENVEM);
- + long compressSize = getInt();
- +
- + // Positions the archive at the "filename length" and read the value.
- + skipBytes(ZipConstants.CENNAM - ZipConstants.CENLEN);
- + int fileNameLen = getShort();
- +
- + // Reads the extra field length and the comment length.
- + int extraLen = getShort();
- + int commentLen = getShort();
- +
- + // Positions the archive at the "local file header offset" and read the value.
- + skipBytes(ZipConstants.CENOFF - ZipConstants.CENDSK);
- + long localHeaderOffset = getInt();
- +
- + // Reads the file name.
- + byte[] fileNameBuf = new byte[fileNameLen];
- + archive.read(ByteBuffer.wrap(fileNameBuf));
- + String fileName = new String(fileNameBuf, Charset.forName("UTF-8"));
- +
- + // Skips the extra field and the comment.
- + skipBytes(extraLen + commentLen);
- +
- + ZipEntry entry = new ZipEntry();
- + entry.setSize(compressSize);
- + entry.setLocalHeaderOffset(localHeaderOffset);
- + entry.setName(fileName);
- +
- + return entry;
- + }
-
- - public long getLocalHeaderOffset() {
- - return localHeaderOffset;
- - }
- + /** Walks through all recorded entries and records the offsets for the entry data. */
- + private Map<String, List<ZipEntry>> parseLocalFileHeaderData(List<ZipEntry> entries) {
- + /** Maps String to list of ZipEntrys, name -> actual entries. */
- + Map<String, List<ZipEntry>> nameMap = new LinkedHashMap<>();
- +
- + for (ZipEntry entry : entries) {
- + long offset = entry.getLocalHeaderOffset();
- + archive.position(offset + ZipConstants.LOCNAM);
- +
- + // Gets the data offset of this entry.
- + int fileNameLen = getShort();
- + int extraFieldLen = getShort();
- + long dataOffset = offset + ZipConstants.LOCEXT + ZipConstants.SHORT_BYTE_SIZE
- + + fileNameLen + extraFieldLen;
- + entry.setDataOffset(dataOffset);
- +
- + // Puts the entry into the nameMap.
- + String name = entry.getName();
- + List<ZipEntry> entriesWithTheSameName;
- + if (nameMap.containsKey(name)) {
- + entriesWithTheSameName = nameMap.get(name);
- + } else {
- + entriesWithTheSameName = new ArrayList<>();
- + nameMap.put(name, entriesWithTheSameName);
- + }
- + entriesWithTheSameName.add(entry);
- + }
- +
- + return nameMap;
- + }
-
- - public void setSize(long size) {
- - this.size = size;
- + /** Skips the given number of bytes or throws an EOFException if skipping failed. */
- + private void skipBytes(int count) throws IOException {
- + long currentPosition = archive.position();
- + long newPosition = currentPosition + count;
- + if (newPosition > archive.size()) {
- + throw new EOFException();
- + }
- + archive.position(newPosition);
- + }
- }
-
- - public void setDataOffset(long dataOffset) {
- - this.dataOffset = dataOffset;
- - }
- + /** Stores the data offset and the size of an entry in the archive. */
- + private static class ZipEntry {
- + private String name;
- + private long dataOffset = -1;
- + private long size = -1;
- + private long localHeaderOffset = -1;
-
- - public void setName(String name) {
- - this.name = name;
- - }
- + public long getSize() {
- + return size;
- + }
-
- - public void setLocalHeaderOffset(long localHeaderOffset) {
- - this.localHeaderOffset = localHeaderOffset;
- - }
- - }
- + public long getDataOffset() {
- + return dataOffset;
- + }
-
- - /**
- - * Various constants for this {@link ZipFile}.
- - *
- - * <p>Referenced from {@link java.util.zip.ZipConstants}.
- - */
- - private static class ZipConstants {
- - /** length of Java short in bytes. */
- - static final int SHORT_BYTE_SIZE = Short.SIZE / 8;
- + public String getName() {
- + return name;
- + }
-
- - /** length of Java int in bytes. */
- - static final int INT_BYTE_SIZE = Integer.SIZE / 8;
- + public long getLocalHeaderOffset() {
- + return localHeaderOffset;
- + }
-
- - /** length of Java long in bytes. */
- - static final int LONG_BYTE_SIZE = Long.SIZE / 8;
- + public void setSize(long size) {
- + this.size = size;
- + }
-
- - /*
- - * Header signatures
- - */
- - static final long LOCSIG = 0x04034b50L; // "PK\003\004"
- - static final long EXTSIG = 0x08074b50L; // "PK\007\008"
- - static final long CENSIG = 0x02014b50L; // "PK\001\002"
- - static final long ENDSIG = 0x06054b50L; // "PK\005\006"
- + public void setDataOffset(long dataOffset) {
- + this.dataOffset = dataOffset;
- + }
-
- - /*
- - * Header sizes in bytes (including signatures)
- - */
- - static final int LOCHDR = 30; // LOC header size
- - static final int EXTHDR = 16; // EXT header size
- - static final int CENHDR = 46; // CEN header size
- - static final int ENDHDR = 22; // END header size
- + public void setName(String name) {
- + this.name = name;
- + }
-
- - /*
- - * Local file (LOC) header field offsets
- - */
- - static final int LOCVER = 4; // version needed to extract
- - static final int LOCFLG = 6; // general purpose bit flag
- - static final int LOCHOW = 8; // compression method
- - static final int LOCTIM = 10; // modification time
- - static final int LOCCRC = 14; // uncompressed file crc-32 value
- - static final int LOCSIZ = 18; // compressed size
- - static final int LOCLEN = 22; // uncompressed size
- - static final int LOCNAM = 26; // filename length
- - static final int LOCEXT = 28; // extra field length
- -
- - /*
- - * Extra local (EXT) header field offsets
- - */
- - static final int EXTCRC = 4; // uncompressed file crc-32 value
- - static final int EXTSIZ = 8; // compressed size
- - static final int EXTLEN = 12; // uncompressed size
- + public void setLocalHeaderOffset(long localHeaderOffset) {
- + this.localHeaderOffset = localHeaderOffset;
- + }
- + }
-
- - /*
- - * Central directory (CEN) header field offsets
- - */
- - static final int CENVEM = 4; // version made by
- - static final int CENVER = 6; // version needed to extract
- - static final int CENFLG = 8; // encrypt, decrypt flags
- - static final int CENHOW = 10; // compression method
- - static final int CENTIM = 12; // modification time
- - static final int CENCRC = 16; // uncompressed file crc-32 value
- - static final int CENSIZ = 20; // compressed size
- - static final int CENLEN = 24; // uncompressed size
- - static final int CENNAM = 28; // filename length
- - static final int CENEXT = 30; // extra field length
- - static final int CENCOM = 32; // comment length
- - static final int CENDSK = 34; // disk number start
- - static final int CENATT = 36; // internal file attributes
- - static final int CENATX = 38; // external file attributes
- - static final int CENOFF = 42; // LOC header offset
- -
- - /*
- - * End of central directory (END) header field offsets
- + /**
- + * Various constants for this {@link ZipFile}.
- + *
- + * <p>Referenced from {@link java.util.zip.ZipConstants}.
- */
- - static final int ENDSUB = 8; // number of entries on this disk
- - static final int ENDTOT = 10; // total number of entries
- - static final int ENDSIZ = 12; // central directory size in bytes
- - static final int ENDOFF = 16; // offset of first CEN header
- - static final int ENDCOM = 20; // zip file comment length
- -
- - private ZipConstants() {}
- - }
- + private static class ZipConstants {
- + /** length of Java short in bytes. */
- + static final int SHORT_BYTE_SIZE = Short.SIZE / 8;
- +
- + /** length of Java int in bytes. */
- + static final int INT_BYTE_SIZE = Integer.SIZE / 8;
- +
- + /** length of Java long in bytes. */
- + static final int LONG_BYTE_SIZE = Long.SIZE / 8;
- +
- + /*
- + * Header signatures
- + */
- + static final long LOCSIG = 0x04034b50L; // "PK\003\004"
- + static final long EXTSIG = 0x08074b50L; // "PK\007\008"
- + static final long CENSIG = 0x02014b50L; // "PK\001\002"
- + static final long ENDSIG = 0x06054b50L; // "PK\005\006"
- +
- + /*
- + * Header sizes in bytes (including signatures)
- + */
- + static final int LOCHDR = 30; // LOC header size
- + static final int EXTHDR = 16; // EXT header size
- + static final int CENHDR = 46; // CEN header size
- + static final int ENDHDR = 22; // END header size
- +
- + /*
- + * Local file (LOC) header field offsets
- + */
- + static final int LOCVER = 4; // version needed to extract
- + static final int LOCFLG = 6; // general purpose bit flag
- + static final int LOCHOW = 8; // compression method
- + static final int LOCTIM = 10; // modification time
- + static final int LOCCRC = 14; // uncompressed file crc-32 value
- + static final int LOCSIZ = 18; // compressed size
- + static final int LOCLEN = 22; // uncompressed size
- + static final int LOCNAM = 26; // filename length
- + static final int LOCEXT = 28; // extra field length
- +
- + /*
- + * Extra local (EXT) header field offsets
- + */
- + static final int EXTCRC = 4; // uncompressed file crc-32 value
- + static final int EXTSIZ = 8; // compressed size
- + static final int EXTLEN = 12; // uncompressed size
- +
- + /*
- + * Central directory (CEN) header field offsets
- + */
- + static final int CENVEM = 4; // version made by
- + static final int CENVER = 6; // version needed to extract
- + static final int CENFLG = 8; // encrypt, decrypt flags
- + static final int CENHOW = 10; // compression method
- + static final int CENTIM = 12; // modification time
- + static final int CENCRC = 16; // uncompressed file crc-32 value
- + static final int CENSIZ = 20; // compressed size
- + static final int CENLEN = 24; // uncompressed size
- + static final int CENNAM = 28; // filename length
- + static final int CENEXT = 30; // extra field length
- + static final int CENCOM = 32; // comment length
- + static final int CENDSK = 34; // disk number start
- + static final int CENATT = 36; // internal file attributes
- + static final int CENATX = 38; // external file attributes
- + static final int CENOFF = 42; // LOC header offset
- +
- + /*
- + * End of central directory (END) header field offsets
- + */
- + static final int ENDSUB = 8; // number of entries on this disk
- + static final int ENDTOT = 10; // total number of entries
- + static final int ENDSIZ = 12; // central directory size in bytes
- + static final int ENDOFF = 16; // offset of first CEN header
- + static final int ENDCOM = 20; // zip file comment length
- +
- + private ZipConstants() {}
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/BoundedInputStreamTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/BoundedInputStreamTest.java
- index 3847bc1d2ce01..e0825a1fe7862 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/BoundedInputStreamTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/BoundedInputStreamTest.java
- @@ -16,244 +16,223 @@ limitations under the License.
- package org.tensorflow.lite.support.metadata;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertArrayEquals;
- import static org.junit.Assert.assertThrows;
-
- -import java.nio.ByteBuffer;
- import org.junit.Test;
- import org.junit.runner.RunWith;
- import org.robolectric.RobolectricTestRunner;
-
- +import java.nio.ByteBuffer;
- +
- /** Tests of {@link BoundedInputStream}. */
- @RunWith(RobolectricTestRunner.class)
- public class BoundedInputStreamTest {
- + private static final byte[] testBytes = new byte[] {10, 20, 30, 40, 50};
- + private static final int[] testInts = new int[] {10, 20, 30, 40, 50};
- + private static final int TEST_BYTES_LENGTH = testBytes.length;
- +
- + @Test
- + public void boundedInputStream_negtiveStart_throwsException() throws Exception {
- + long start = -1;
- + long remaining = 2;
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> createBoundedInputStream(testBytes, start, remaining));
- + assertThat(exception).hasMessageThat().isEqualTo(String.format(
- + "Invalid length of stream at offset=%d, length=%d", start, remaining));
- + }
- +
- + @Test
- + public void boundedInputStream_negtiveRemaining_throwsException() throws Exception {
- + long start = 1;
- + long remaining = -2;
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> createBoundedInputStream(testBytes, start, remaining));
- + assertThat(exception).hasMessageThat().isEqualTo(String.format(
- + "Invalid length of stream at offset=%d, length=%d", start, remaining));
- + }
- +
- + @Test
- + public void available_atStart() throws Exception {
- + int start = 3;
- + BoundedInputStream boundedInputStream =
- + createBoundedInputStream(testBytes, start, TEST_BYTES_LENGTH);
- +
- + int available = boundedInputStream.available();
- + assertThat(available).isEqualTo(TEST_BYTES_LENGTH - start);
- + }
- +
- + @Test
- + public void available_afterRead() throws Exception {
- + BoundedInputStream boundedInputStream =
- + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- + // Read a byte out of boundedInputStream. The number of remaining bytes is TEST_BYTES_LENGTH
- + // -1.
- + boundedInputStream.read();
- +
- + int available = boundedInputStream.available();
- + assertThat(available).isEqualTo(TEST_BYTES_LENGTH - 1);
- + }
- +
- + @Test
- + public void read_repeatedRead() throws Exception {
- + int[] values = new int[TEST_BYTES_LENGTH];
- + BoundedInputStream boundedInputStream =
- + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- +
- + for (int i = 0; i < TEST_BYTES_LENGTH; i++) {
- + values[i] = boundedInputStream.read();
- + }
- +
- + assertArrayEquals(testInts, values);
- + }
- +
- + @Test
- + public void read_reachTheEnd() throws Exception {
- + BoundedInputStream boundedInputStream =
- + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- + boundedInputStream.skip(TEST_BYTES_LENGTH);
- + int value = boundedInputStream.read();
- +
- + assertThat(value).isEqualTo(-1);
- + }
- +
- + @Test
- + public void read_channelSizeisSmallerThanTheStreamSpecified() throws Exception {
- + BoundedInputStream boundedInputStream =
- + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH + 1);
- + boundedInputStream.skip(TEST_BYTES_LENGTH);
- +
- + int value = boundedInputStream.read();
- +
- + assertThat(value).isEqualTo(-1);
- + }
-
- - private static final byte[] testBytes = new byte[] {10, 20, 30, 40, 50};
- - private static final int[] testInts = new int[] {10, 20, 30, 40, 50};
- - private static final int TEST_BYTES_LENGTH = testBytes.length;
- -
- - @Test
- - public void boundedInputStream_negtiveStart_throwsException() throws Exception {
- - long start = -1;
- - long remaining = 2;
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class,
- - () -> createBoundedInputStream(testBytes, start, remaining));
- - assertThat(exception)
- - .hasMessageThat()
- - .isEqualTo(
- - String.format("Invalid length of stream at offset=%d, length=%d", start, remaining));
- - }
- -
- - @Test
- - public void boundedInputStream_negtiveRemaining_throwsException() throws Exception {
- - long start = 1;
- - long remaining = -2;
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class,
- - () -> createBoundedInputStream(testBytes, start, remaining));
- - assertThat(exception)
- - .hasMessageThat()
- - .isEqualTo(
- - String.format("Invalid length of stream at offset=%d, length=%d", start, remaining));
- - }
- -
- - @Test
- - public void available_atStart() throws Exception {
- - int start = 3;
- - BoundedInputStream boundedInputStream =
- - createBoundedInputStream(testBytes, start, TEST_BYTES_LENGTH);
- -
- - int available = boundedInputStream.available();
- - assertThat(available).isEqualTo(TEST_BYTES_LENGTH - start);
- - }
- -
- - @Test
- - public void available_afterRead() throws Exception {
- - BoundedInputStream boundedInputStream =
- - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- - // Read a byte out of boundedInputStream. The number of remaining bytes is TEST_BYTES_LENGTH -1.
- - boundedInputStream.read();
- -
- - int available = boundedInputStream.available();
- - assertThat(available).isEqualTo(TEST_BYTES_LENGTH - 1);
- - }
- -
- - @Test
- - public void read_repeatedRead() throws Exception {
- - int[] values = new int[TEST_BYTES_LENGTH];
- - BoundedInputStream boundedInputStream =
- - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- -
- - for (int i = 0; i < TEST_BYTES_LENGTH; i++) {
- - values[i] = boundedInputStream.read();
- + @Test
- + public void readArray_nullArray_throwsException() throws Exception {
- + byte[] array = null;
- + int offset = 0;
- + int length = 1;
- + BoundedInputStream boundedInputStream =
- + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- +
- + NullPointerException exception = assertThrows(
- + NullPointerException.class, () -> boundedInputStream.read(array, offset, length));
- + assertThat(exception).hasMessageThat().isEqualTo("The object reference is null.");
- }
-
- - assertArrayEquals(testInts, values);
- - }
- -
- - @Test
- - public void read_reachTheEnd() throws Exception {
- - BoundedInputStream boundedInputStream =
- - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- - boundedInputStream.skip(TEST_BYTES_LENGTH);
- - int value = boundedInputStream.read();
- -
- - assertThat(value).isEqualTo(-1);
- - }
- -
- - @Test
- - public void read_channelSizeisSmallerThanTheStreamSpecified() throws Exception {
- - BoundedInputStream boundedInputStream =
- - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH + 1);
- - boundedInputStream.skip(TEST_BYTES_LENGTH);
- -
- - int value = boundedInputStream.read();
- -
- - assertThat(value).isEqualTo(-1);
- - }
- -
- - @Test
- - public void readArray_nullArray_throwsException() throws Exception {
- - byte[] array = null;
- - int offset = 0;
- - int length = 1;
- - BoundedInputStream boundedInputStream =
- - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- -
- - NullPointerException exception =
- - assertThrows(
- - NullPointerException.class, () -> boundedInputStream.read(array, offset, length));
- - assertThat(exception).hasMessageThat().isEqualTo("The object reference is null.");
- - }
- -
- - @Test
- - public void readArray_negativeOffset_throwsException() throws Exception {
- - byte[] array = new byte[5];
- - int offset = -1;
- - int length = array.length;
- - BoundedInputStream boundedInputStream =
- - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- -
- - IndexOutOfBoundsException exception =
- - assertThrows(
- - IndexOutOfBoundsException.class, () -> boundedInputStream.read(array, offset, length));
- - assertThat(exception)
- - .hasMessageThat()
- - .isEqualTo(String.format("The start offset (%s) must not be negative", offset));
- - }
- -
- - @Test
- - public void readArray_OffsetEqualsArrayLength_throwsException() throws Exception {
- - byte[] array = new byte[5];
- - int offset = array.length;
- - int length = 0;
- - BoundedInputStream boundedInputStream =
- - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- -
- - IndexOutOfBoundsException exception =
- - assertThrows(
- - IndexOutOfBoundsException.class, () -> boundedInputStream.read(array, offset, length));
- - assertThat(exception)
- - .hasMessageThat()
- - .isEqualTo(
- - String.format(
- + @Test
- + public void readArray_negativeOffset_throwsException() throws Exception {
- + byte[] array = new byte[5];
- + int offset = -1;
- + int length = array.length;
- + BoundedInputStream boundedInputStream =
- + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- +
- + IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
- + () -> boundedInputStream.read(array, offset, length));
- + assertThat(exception).hasMessageThat().isEqualTo(
- + String.format("The start offset (%s) must not be negative", offset));
- + }
- +
- + @Test
- + public void readArray_OffsetEqualsArrayLength_throwsException() throws Exception {
- + byte[] array = new byte[5];
- + int offset = array.length;
- + int length = 0;
- + BoundedInputStream boundedInputStream =
- + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- +
- + IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
- + () -> boundedInputStream.read(array, offset, length));
- + assertThat(exception).hasMessageThat().isEqualTo(String.format(
- "The start offset (%s) must be less than size (%s)", offset, array.length));
- - }
- -
- - @Test
- - public void readArray_negativeLength_throwsException() throws Exception {
- - byte[] array = new byte[5];
- - int offset = 0;
- - int length = -1;
- - BoundedInputStream boundedInputStream =
- - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- -
- - IndexOutOfBoundsException exception =
- - assertThrows(
- - IndexOutOfBoundsException.class, () -> boundedInputStream.read(array, offset, length));
- - assertThat(exception)
- - .hasMessageThat()
- - .isEqualTo(
- - String.format(
- + }
- +
- + @Test
- + public void readArray_negativeLength_throwsException() throws Exception {
- + byte[] array = new byte[5];
- + int offset = 0;
- + int length = -1;
- + BoundedInputStream boundedInputStream =
- + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- +
- + IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
- + () -> boundedInputStream.read(array, offset, length));
- + assertThat(exception).hasMessageThat().isEqualTo(String.format(
- "The maximumn number of bytes to read (%s) must not be negative", length));
- - }
- -
- - @Test
- - public void readArray_exceedEndOfArray_throwsException() throws Exception {
- - byte[] array = new byte[5];
- - int offset = 0;
- - int length = array.length + 1;
- - BoundedInputStream boundedInputStream =
- - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- -
- - IndexOutOfBoundsException exception =
- - assertThrows(
- - IndexOutOfBoundsException.class, () -> boundedInputStream.read(array, offset, length));
- - assertThat(exception)
- - .hasMessageThat()
- - .isEqualTo(
- - String.format(
- - "The maximumn number of bytes to read (%s) must be less than size (%s)",
- - length, array.length - offset + 1));
- - }
- -
- - @Test
- - public void readArray_zeroLength() throws Exception {
- - byte[] array = new byte[5];
- - int offset = 0;
- - int length = 0;
- - BoundedInputStream boundedInputStream =
- - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- -
- - int value = boundedInputStream.read(array, offset, length);
- - assertThat(value).isEqualTo(0);
- - }
- -
- - @Test
- - public void readArray_exceedEndOfStream() throws Exception {
- - byte[] array = new byte[5];
- - int offset = 0;
- - int length = 1;
- - BoundedInputStream boundedInputStream =
- - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- -
- - // Move the position of the stream to the end.
- - boundedInputStream.skip(TEST_BYTES_LENGTH);
- -
- - int value = boundedInputStream.read(array, offset, length);
- -
- - assertThat(value).isEqualTo(-1);
- - }
- -
- - @Test
- - public void readArray_lengthGreaterThanStreamRemaining() throws Exception {
- - byte[] array = new byte[5];
- - int offset = 1;
- - int length = array.length - 1; // 4
- - BoundedInputStream boundedInputStream =
- - createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- -
- - // Moves the position of the stream to end-2.
- - boundedInputStream.skip(TEST_BYTES_LENGTH - 2);
- -
- - // Reads the last two bytes of the stream to the array, and put the data at offset 1.
- - int value = boundedInputStream.read(array, offset, length);
- -
- - byte[] expectedArray = new byte[] {0, 40, 50, 0, 0};
- - assertArrayEquals(expectedArray, array);
- - assertThat(value).isEqualTo(2);
- -
- - // Reachs the end of the stream, thus cannot read anymore.
- - assertThat(boundedInputStream.read()).isEqualTo(-1);
- - }
- -
- - private static BoundedInputStream createBoundedInputStream(
- - final byte[] testBytes, long start, long remaining) {
- - ByteBuffer buffer = ByteBuffer.wrap(testBytes);
- - SeekableByteChannelCompat channel = new ByteBufferChannel(buffer);
- - return new BoundedInputStream(channel, start, remaining);
- - }
- + }
- +
- + @Test
- + public void readArray_exceedEndOfArray_throwsException() throws Exception {
- + byte[] array = new byte[5];
- + int offset = 0;
- + int length = array.length + 1;
- + BoundedInputStream boundedInputStream =
- + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- +
- + IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
- + () -> boundedInputStream.read(array, offset, length));
- + assertThat(exception).hasMessageThat().isEqualTo(String.format(
- + "The maximumn number of bytes to read (%s) must be less than size (%s)", length,
- + array.length - offset + 1));
- + }
- +
- + @Test
- + public void readArray_zeroLength() throws Exception {
- + byte[] array = new byte[5];
- + int offset = 0;
- + int length = 0;
- + BoundedInputStream boundedInputStream =
- + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- +
- + int value = boundedInputStream.read(array, offset, length);
- + assertThat(value).isEqualTo(0);
- + }
- +
- + @Test
- + public void readArray_exceedEndOfStream() throws Exception {
- + byte[] array = new byte[5];
- + int offset = 0;
- + int length = 1;
- + BoundedInputStream boundedInputStream =
- + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- +
- + // Move the position of the stream to the end.
- + boundedInputStream.skip(TEST_BYTES_LENGTH);
- +
- + int value = boundedInputStream.read(array, offset, length);
- +
- + assertThat(value).isEqualTo(-1);
- + }
- +
- + @Test
- + public void readArray_lengthGreaterThanStreamRemaining() throws Exception {
- + byte[] array = new byte[5];
- + int offset = 1;
- + int length = array.length - 1; // 4
- + BoundedInputStream boundedInputStream =
- + createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- +
- + // Moves the position of the stream to end-2.
- + boundedInputStream.skip(TEST_BYTES_LENGTH - 2);
- +
- + // Reads the last two bytes of the stream to the array, and put the data at offset 1.
- + int value = boundedInputStream.read(array, offset, length);
- +
- + byte[] expectedArray = new byte[] {0, 40, 50, 0, 0};
- + assertArrayEquals(expectedArray, array);
- + assertThat(value).isEqualTo(2);
- +
- + // Reachs the end of the stream, thus cannot read anymore.
- + assertThat(boundedInputStream.read()).isEqualTo(-1);
- + }
- +
- + private static BoundedInputStream createBoundedInputStream(
- + final byte[] testBytes, long start, long remaining) {
- + ByteBuffer buffer = ByteBuffer.wrap(testBytes);
- + SeekableByteChannelCompat channel = new ByteBufferChannel(buffer);
- + return new BoundedInputStream(channel, start, remaining);
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ByteBufferChannelTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ByteBufferChannelTest.java
- index abda43058aa90..ce625de8034b7 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ByteBufferChannelTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ByteBufferChannelTest.java
- @@ -16,254 +16,252 @@ limitations under the License.
- package org.tensorflow.lite.support.metadata;
-
- import static com.google.common.truth.Truth.assertThat;
- -import static java.nio.charset.StandardCharsets.UTF_8;
- +
- import static org.junit.Assert.assertThrows;
-
- -import java.nio.ByteBuffer;
- +import static java.nio.charset.StandardCharsets.UTF_8;
- +
- import org.junit.Test;
- import org.junit.runner.RunWith;
- import org.robolectric.RobolectricTestRunner;
-
- +import java.nio.ByteBuffer;
- +
- /** Tests of {@link ByteBufferChannel}. */
- @RunWith(RobolectricTestRunner.class)
- public final class ByteBufferChannelTest {
- - private static final String VALID_STRING = "1234567890";
- - private final ByteBuffer validByteBuffer = ByteBuffer.wrap(VALID_STRING.getBytes(UTF_8));
- - private final int validByteBufferLength = validByteBuffer.limit();
- -
- - @Test
- - public void byteBufferChannel_validByteBuffer() {
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - assertThat(byteBufferChannel).isNotNull();
- - }
- -
- - @Test
- - public void byteBufferChannel_nullByteBuffer_throwsException() {
- - NullPointerException exception =
- - assertThrows(NullPointerException.class, () -> new ByteBufferChannel(/*buffer=*/ null));
- - assertThat(exception).hasMessageThat().isEqualTo("The ByteBuffer cannot be null.");
- - }
- -
- - @Test
- - public void isOpen_openedByteBufferChannel() {
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - assertThat(byteBufferChannel.isOpen()).isTrue();
- - }
- -
- - @Test
- - public void position_newByteBufferChannelWithInitialPosition0() {
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - long position = byteBufferChannel.position();
- -
- - long expectedPosition = 0;
- - assertThat(position).isEqualTo(expectedPosition);
- - }
- -
- - @Test
- - public void position_validNewPosition() {
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - long validNewPosition = 6;
- -
- - byteBufferChannel.position(validNewPosition);
- - long position = byteBufferChannel.position();
- -
- - assertThat(position).isEqualTo(validNewPosition);
- - }
- -
- - @Test
- - public void position_negtiveNewPosition_throwsException() {
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - long invalidNewPosition = -1;
- -
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> byteBufferChannel.position(invalidNewPosition));
- - assertThat(exception)
- - .hasMessageThat()
- - .isEqualTo("The new position should be non-negative and be less than Integer.MAX_VALUE.");
- - }
- -
- - @Test
- - public void position_newPositionGreaterThanMaxIntegerValue_throwsException() {
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - long invalidNewPosition = Integer.MAX_VALUE + 1;
- -
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> byteBufferChannel.position(invalidNewPosition));
- - assertThat(exception)
- - .hasMessageThat()
- - .isEqualTo("The new position should be non-negative and be less than Integer.MAX_VALUE.");
- - }
- -
- - @Test
- - public void position_newPositionGreaterThanByfferLength_throwsException() {
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - long invalidNewPosition = (long) validByteBufferLength + 1;
- -
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> byteBufferChannel.position(invalidNewPosition));
- - assertThat(exception).hasMessageThat().isEqualTo("newPosition > limit: (11 > 10)");
- - }
- -
- - @Test
- - public void read_fromPosition0() {
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - long validNewPosition = 0;
- -
- - byteBufferChannel.position(validNewPosition);
- - ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- - int numBytes = byteBufferChannel.read(dstBuffer);
- -
- - assertThat(numBytes).isEqualTo(validByteBufferLength);
- - assertThat(dstBuffer).isEqualTo(validByteBuffer);
- - }
- -
- - @Test
- - public void read_fromPosition5() {
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - long validNewPosition = 5;
- -
- - byteBufferChannel.position(validNewPosition);
- - ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- - int numBytes = byteBufferChannel.read(dstBuffer);
- -
- - assertThat(numBytes).isEqualTo(validByteBufferLength - (int) validNewPosition);
- - String dstString = convertByteByfferToString(dstBuffer, numBytes);
- - String expectedString = "67890";
- - assertThat(dstString).isEqualTo(expectedString);
- - }
- -
- - @Test
- - public void read_fromPositionValidByteBufferLength() {
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - long validNewPosition = validByteBufferLength;
- -
- - byteBufferChannel.position(validNewPosition);
- - ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- - int numBytes = byteBufferChannel.read(dstBuffer);
- -
- - assertThat(numBytes).isEqualTo(-1);
- - }
- -
- - @Test
- - public void read_dstBufferRemaining0() {
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - long validNewPosition = 0;
- -
- - byteBufferChannel.position(validNewPosition);
- - ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- - dstBuffer.position(validByteBufferLength);
- - int numBytes = byteBufferChannel.read(dstBuffer);
- -
- - assertThat(numBytes).isEqualTo(0);
- - String dstString = convertByteByfferToString(dstBuffer, numBytes);
- - String expectedString = "";
- - assertThat(dstString).isEqualTo(expectedString);
- - }
- -
- - @Test
- - public void read_dstBufferIsSmallerThanTheBufferChannel() {
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - int dstBufferLength = 3;
- -
- - ByteBuffer dstBuffer = ByteBuffer.allocate(dstBufferLength);
- - int numBytes = byteBufferChannel.read(dstBuffer);
- -
- - assertThat(numBytes).isEqualTo(dstBufferLength);
- - assertThat(validByteBuffer.position()).isEqualTo(dstBufferLength);
- -
- - String dstString = convertByteByfferToString(dstBuffer, dstBufferLength);
- - String expectedString = "123";
- - assertThat(dstString).isEqualTo(expectedString);
- - }
- -
- - @Test
- - public void size_validBuffer() {
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - assertThat(byteBufferChannel.size()).isEqualTo(validByteBufferLength);
- - }
- -
- - @Test
- - public void truncate_validSizeAndPosition0() {
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - long truncateSize = 3;
- -
- - byteBufferChannel.truncate(truncateSize);
- -
- - assertThat(byteBufferChannel.size()).isEqualTo(truncateSize);
- - assertThat(byteBufferChannel.position()).isEqualTo(0);
- - }
- -
- - @Test
- - public void truncate_validSizeAndPosition5() {
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - long validNewPosition = 5;
- -
- - byteBufferChannel.position(validNewPosition);
- - long truncateSize = 3;
- - byteBufferChannel.truncate(truncateSize);
- -
- - assertThat(byteBufferChannel.position()).isEqualTo(truncateSize);
- - }
- -
- - @Test
- - public void truncate_sizeNotSmallerThanBufferSize() {
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - long truncateSize = (long) validByteBufferLength;
- -
- - byteBufferChannel.truncate(truncateSize);
- -
- - assertThat(byteBufferChannel.position()).isEqualTo(0);
- - }
- -
- - @Test
- - public void write_srcBufferSmallerThanBufferChannel() {
- - String srcString = "5555";
- - long newPosition = 3;
- - String expectedString = "1235555890";
- - ByteBuffer srcBuffer = ByteBuffer.wrap(srcString.getBytes(UTF_8));
- -
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - byteBufferChannel.position(newPosition);
- - byteBufferChannel.write(srcBuffer);
- -
- - assertThat(byteBufferChannel.position()).isEqualTo(newPosition + srcString.length());
- - ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- - byteBufferChannel.position(0);
- - byteBufferChannel.read(dstBuffer);
- - ByteBuffer expectedBuffer = ByteBuffer.wrap(expectedString.getBytes(UTF_8));
- - dstBuffer.rewind();
- - expectedBuffer.rewind();
- - assertThat(dstBuffer).isEqualTo(expectedBuffer);
- - }
- -
- - @Test
- - public void write_srcBufferGreaterThanBufferChannel() {
- - String srcString = "5555";
- - long newPosition = 8;
- - String expectedString = "1234567855";
- - ByteBuffer srcBuffer = ByteBuffer.wrap(srcString.getBytes(UTF_8));
- -
- - ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- - byteBufferChannel.position(newPosition);
- - byteBufferChannel.write(srcBuffer);
- - assertThat(byteBufferChannel.position()).isEqualTo(validByteBufferLength);
- -
- - ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- - byteBufferChannel.position(0);
- - byteBufferChannel.read(dstBuffer);
- - ByteBuffer expectedBuffer = ByteBuffer.wrap(expectedString.getBytes(UTF_8));
- - dstBuffer.rewind();
- - expectedBuffer.rewind();
- - assertThat(dstBuffer).isEqualTo(expectedBuffer);
- - }
- -
- - private static String convertByteByfferToString(ByteBuffer buffer, int arrLength) {
- - byte[] bytes = new byte[arrLength];
- - buffer.rewind();
- - buffer.get(bytes);
- - return new String(bytes, UTF_8);
- - }
- + private static final String VALID_STRING = "1234567890";
- + private final ByteBuffer validByteBuffer = ByteBuffer.wrap(VALID_STRING.getBytes(UTF_8));
- + private final int validByteBufferLength = validByteBuffer.limit();
- +
- + @Test
- + public void byteBufferChannel_validByteBuffer() {
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + assertThat(byteBufferChannel).isNotNull();
- + }
- +
- + @Test
- + public void byteBufferChannel_nullByteBuffer_throwsException() {
- + NullPointerException exception = assertThrows(
- + NullPointerException.class, () -> new ByteBufferChannel(/*buffer=*/null));
- + assertThat(exception).hasMessageThat().isEqualTo("The ByteBuffer cannot be null.");
- + }
- +
- + @Test
- + public void isOpen_openedByteBufferChannel() {
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + assertThat(byteBufferChannel.isOpen()).isTrue();
- + }
- +
- + @Test
- + public void position_newByteBufferChannelWithInitialPosition0() {
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + long position = byteBufferChannel.position();
- +
- + long expectedPosition = 0;
- + assertThat(position).isEqualTo(expectedPosition);
- + }
- +
- + @Test
- + public void position_validNewPosition() {
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + long validNewPosition = 6;
- +
- + byteBufferChannel.position(validNewPosition);
- + long position = byteBufferChannel.position();
- +
- + assertThat(position).isEqualTo(validNewPosition);
- + }
- +
- + @Test
- + public void position_negtiveNewPosition_throwsException() {
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + long invalidNewPosition = -1;
- +
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> byteBufferChannel.position(invalidNewPosition));
- + assertThat(exception).hasMessageThat().isEqualTo(
- + "The new position should be non-negative and be less than Integer.MAX_VALUE.");
- + }
- +
- + @Test
- + public void position_newPositionGreaterThanMaxIntegerValue_throwsException() {
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + long invalidNewPosition = Integer.MAX_VALUE + 1;
- +
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> byteBufferChannel.position(invalidNewPosition));
- + assertThat(exception).hasMessageThat().isEqualTo(
- + "The new position should be non-negative and be less than Integer.MAX_VALUE.");
- + }
- +
- + @Test
- + public void position_newPositionGreaterThanByfferLength_throwsException() {
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + long invalidNewPosition = (long) validByteBufferLength + 1;
- +
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> byteBufferChannel.position(invalidNewPosition));
- + assertThat(exception).hasMessageThat().isEqualTo("newPosition > limit: (11 > 10)");
- + }
- +
- + @Test
- + public void read_fromPosition0() {
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + long validNewPosition = 0;
- +
- + byteBufferChannel.position(validNewPosition);
- + ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- + int numBytes = byteBufferChannel.read(dstBuffer);
- +
- + assertThat(numBytes).isEqualTo(validByteBufferLength);
- + assertThat(dstBuffer).isEqualTo(validByteBuffer);
- + }
- +
- + @Test
- + public void read_fromPosition5() {
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + long validNewPosition = 5;
- +
- + byteBufferChannel.position(validNewPosition);
- + ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- + int numBytes = byteBufferChannel.read(dstBuffer);
- +
- + assertThat(numBytes).isEqualTo(validByteBufferLength - (int) validNewPosition);
- + String dstString = convertByteByfferToString(dstBuffer, numBytes);
- + String expectedString = "67890";
- + assertThat(dstString).isEqualTo(expectedString);
- + }
- +
- + @Test
- + public void read_fromPositionValidByteBufferLength() {
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + long validNewPosition = validByteBufferLength;
- +
- + byteBufferChannel.position(validNewPosition);
- + ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- + int numBytes = byteBufferChannel.read(dstBuffer);
- +
- + assertThat(numBytes).isEqualTo(-1);
- + }
- +
- + @Test
- + public void read_dstBufferRemaining0() {
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + long validNewPosition = 0;
- +
- + byteBufferChannel.position(validNewPosition);
- + ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- + dstBuffer.position(validByteBufferLength);
- + int numBytes = byteBufferChannel.read(dstBuffer);
- +
- + assertThat(numBytes).isEqualTo(0);
- + String dstString = convertByteByfferToString(dstBuffer, numBytes);
- + String expectedString = "";
- + assertThat(dstString).isEqualTo(expectedString);
- + }
- +
- + @Test
- + public void read_dstBufferIsSmallerThanTheBufferChannel() {
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + int dstBufferLength = 3;
- +
- + ByteBuffer dstBuffer = ByteBuffer.allocate(dstBufferLength);
- + int numBytes = byteBufferChannel.read(dstBuffer);
- +
- + assertThat(numBytes).isEqualTo(dstBufferLength);
- + assertThat(validByteBuffer.position()).isEqualTo(dstBufferLength);
- +
- + String dstString = convertByteByfferToString(dstBuffer, dstBufferLength);
- + String expectedString = "123";
- + assertThat(dstString).isEqualTo(expectedString);
- + }
- +
- + @Test
- + public void size_validBuffer() {
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + assertThat(byteBufferChannel.size()).isEqualTo(validByteBufferLength);
- + }
- +
- + @Test
- + public void truncate_validSizeAndPosition0() {
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + long truncateSize = 3;
- +
- + byteBufferChannel.truncate(truncateSize);
- +
- + assertThat(byteBufferChannel.size()).isEqualTo(truncateSize);
- + assertThat(byteBufferChannel.position()).isEqualTo(0);
- + }
- +
- + @Test
- + public void truncate_validSizeAndPosition5() {
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + long validNewPosition = 5;
- +
- + byteBufferChannel.position(validNewPosition);
- + long truncateSize = 3;
- + byteBufferChannel.truncate(truncateSize);
- +
- + assertThat(byteBufferChannel.position()).isEqualTo(truncateSize);
- + }
- +
- + @Test
- + public void truncate_sizeNotSmallerThanBufferSize() {
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + long truncateSize = (long) validByteBufferLength;
- +
- + byteBufferChannel.truncate(truncateSize);
- +
- + assertThat(byteBufferChannel.position()).isEqualTo(0);
- + }
- +
- + @Test
- + public void write_srcBufferSmallerThanBufferChannel() {
- + String srcString = "5555";
- + long newPosition = 3;
- + String expectedString = "1235555890";
- + ByteBuffer srcBuffer = ByteBuffer.wrap(srcString.getBytes(UTF_8));
- +
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + byteBufferChannel.position(newPosition);
- + byteBufferChannel.write(srcBuffer);
- +
- + assertThat(byteBufferChannel.position()).isEqualTo(newPosition + srcString.length());
- + ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- + byteBufferChannel.position(0);
- + byteBufferChannel.read(dstBuffer);
- + ByteBuffer expectedBuffer = ByteBuffer.wrap(expectedString.getBytes(UTF_8));
- + dstBuffer.rewind();
- + expectedBuffer.rewind();
- + assertThat(dstBuffer).isEqualTo(expectedBuffer);
- + }
- +
- + @Test
- + public void write_srcBufferGreaterThanBufferChannel() {
- + String srcString = "5555";
- + long newPosition = 8;
- + String expectedString = "1234567855";
- + ByteBuffer srcBuffer = ByteBuffer.wrap(srcString.getBytes(UTF_8));
- +
- + ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- + byteBufferChannel.position(newPosition);
- + byteBufferChannel.write(srcBuffer);
- + assertThat(byteBufferChannel.position()).isEqualTo(validByteBufferLength);
- +
- + ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- + byteBufferChannel.position(0);
- + byteBufferChannel.read(dstBuffer);
- + ByteBuffer expectedBuffer = ByteBuffer.wrap(expectedString.getBytes(UTF_8));
- + dstBuffer.rewind();
- + expectedBuffer.rewind();
- + assertThat(dstBuffer).isEqualTo(expectedBuffer);
- + }
- +
- + private static String convertByteByfferToString(ByteBuffer buffer, int arrLength) {
- + byte[] bytes = new byte[arrLength];
- + buffer.rewind();
- + buffer.get(bytes);
- + return new String(bytes, UTF_8);
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java
- index 67fc50d9f57b1..9f1173a1ea19b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java
- @@ -16,24 +16,20 @@ limitations under the License.
- package org.tensorflow.lite.support.metadata;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertArrayEquals;
- import static org.junit.Assert.assertThrows;
-
- import android.content.Context;
- import android.content.res.AssetFileDescriptor;
- +
- import androidx.test.core.app.ApplicationProvider;
- +
- import com.google.flatbuffers.FlatBufferBuilder;
- -import java.io.FileInputStream;
- -import java.io.InputStream;
- -import java.nio.ByteBuffer;
- -import java.nio.channels.FileChannel;
- -import java.util.Arrays;
- -import java.util.Collection;
- -import java.util.HashSet;
- -import java.util.Random;
- -import java.util.Set;
- +
- import org.apache.commons.io.IOUtils;
- import org.checkerframework.checker.nullness.qual.Nullable;
- +import org.junit.Ignore;
- import org.junit.Test;
- import org.junit.runner.RunWith;
- import org.junit.runners.Suite;
- @@ -56,931 +52,903 @@ import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
- import org.tensorflow.lite.support.metadata.schema.SubGraphMetadata;
- import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
-
- -import org.junit.Ignore;
- +import java.io.FileInputStream;
- +import java.io.InputStream;
- +import java.nio.ByteBuffer;
- +import java.nio.channels.FileChannel;
- +import java.util.Arrays;
- +import java.util.Collection;
- +import java.util.HashSet;
- +import java.util.Random;
- +import java.util.Set;
-
- /** Tests of {@link MetadataExtractor}. */
- @RunWith(Suite.class)
- @SuiteClasses({MetadataExtractorTest.General.class, MetadataExtractorTest.InputTensorType.class})
- public class MetadataExtractorTest {
- - private static final int[] validShape = new int[] {4, 10, 10, 3};
- - private static final byte DATA_TYPE = TensorType.UINT8;
- - private static final byte CONTENT_PROPERTIES_TYPE = ContentProperties.ImageProperties;
- - private static final float VALID_SCALE = 3.3f;
- - private static final long VALID_ZERO_POINT = 2;
- - private static final float DEFAULT_SCALE = 0.0f;
- - private static final long DEFAULT_ZERO_POINT = 0;
- - private static final String MODEL_NAME = "model.tflite";
- - // Scale and zero point should both be a single value, not an array.
- - private static final float[] invalidScale = new float[] {0.0f, 1.2f};
- - private static final long[] invalidZeroPoint = new long[] {1, 2};
- - private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite";
- - // labels.txt is packed in mobilenet_v1_1.0_224_quant.tflite as an associated file.
- - private static final String VALID_LABEL_FILE_NAME = "labels.txt";
- - // invalid.txt is not packed in mobilenet_v1_1.0_224_quant.tflite.
- - private static final String INVALID_LABEL_FILE_NAME = "invalid.txt";
- - private static final int EMPTY_FLATBUFFER_VECTOR = -1;
- - private static final String TFLITE_MODEL_IDENTIFIER = "TFL3";
- - private static final String TFLITE_METADATA_IDENTIFIER = "M001";
- -
- - /** General tests of MetadataExtractor. */
- - @RunWith(RobolectricTestRunner.class)
- - public static final class General extends MetadataExtractorTest {
- -
- - @Test
- - public void hasMetadata_modelWithMetadata() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - assertThat(metadataExtractor.hasMetadata()).isTrue();
- - }
- -
- - @Test
- - public void hasMetadata_modelWithoutMetadata() throws Exception {
- - // Creates a model flatbuffer without metadata.
- - ByteBuffer modelWithoutMetadata = createModelByteBuffer(/*metadataBuffer=*/ null, DATA_TYPE);
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
- - assertThat(metadataExtractor.hasMetadata()).isFalse();
- - }
- -
- - @Ignore
- - @Test
- - public void getAssociatedFile_validAssociateFile() throws Exception {
- - ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
- - MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
- - InputStream associateFileStream =
- - mobileNetMetadataExtractor.getAssociatedFile(VALID_LABEL_FILE_NAME);
- -
- - // Reads the golden file from context.
- - Context context = ApplicationProvider.getApplicationContext();
- - InputStream goldenAssociateFileStream = context.getAssets().open(VALID_LABEL_FILE_NAME);
- - assertThat(IOUtils.contentEquals(goldenAssociateFileStream, associateFileStream)).isTrue();
- - }
- -
- - @Ignore
- - @Test
- - public void getAssociatedFile_invalidAssociateFile() throws Exception {
- - ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
- - MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class,
- - () -> mobileNetMetadataExtractor.getAssociatedFile(INVALID_LABEL_FILE_NAME));
- - assertThat(exception)
- - .hasMessageThat()
- - .isEqualTo(
- - String.format(
- - "The file, %s, does not exist in the zip file.", INVALID_LABEL_FILE_NAME));
- - }
- -
- - @Ignore
- - @Test
- - public void getAssociatedFile_nullFileName() throws Exception {
- - ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
- - MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class,
- - () -> mobileNetMetadataExtractor.getAssociatedFile(/*fileName=*/ null));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains("The file, null, does not exist in the zip file.");
- - }
- -
- - @Test
- - public void getAssociatedFile_nonZipModel_throwsException() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - IllegalStateException exception =
- - assertThrows(
- - IllegalStateException.class,
- - () -> metadataExtractor.getAssociatedFile(VALID_LABEL_FILE_NAME));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains("This model does not contain associated files, and is not a Zip file.");
- - }
- -
- - @Test
- - public void getAssociatedFileNames_nonZipModel_throwsException() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - IllegalStateException exception =
- - assertThrows(IllegalStateException.class, metadataExtractor::getAssociatedFileNames);
- - assertThat(exception)
- - .hasMessageThat()
- - .contains("This model does not contain associated files, and is not a Zip file.");
- - }
- -
- - @Ignore
- - @Test
- - public void getAssociatedFileNames_validFileNames() throws Exception {
- - ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
- - MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
- - Set<String> expectedSet = new HashSet<>();
- - expectedSet.add(VALID_LABEL_FILE_NAME);
- - assertThat(mobileNetMetadataExtractor.getAssociatedFileNames()).isEqualTo(expectedSet);
- - }
- -
- - @Test
- - public void metadataExtractor_loadNullBuffer_throwsException() {
- - ByteBuffer nullBuffer = null;
- - NullPointerException exception =
- - assertThrows(NullPointerException.class, () -> new MetadataExtractor(nullBuffer));
- - assertThat(exception).hasMessageThat().contains("Model flatbuffer cannot be null.");
- - }
- -
- - @Test
- - public void metadataExtractor_loadRandomBuffer_throwsException() {
- - ByteBuffer randomBuffer = createRandomByteBuffer();
- - IllegalArgumentException exception =
- - assertThrows(IllegalArgumentException.class, () -> new MetadataExtractor(randomBuffer));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
- - + " flatbuffer.");
- - }
- -
- - @Test
- - public void metadataExtractor_loadModelWithInvalidIdentifier_throwsException() {
- - // Creates a model with an invalid identifier.
- - String invalidIdentifier = "INVI";
- - FlatBufferBuilder builder = new FlatBufferBuilder();
- - Model.startModel(builder);
- - int model = Model.endModel(builder);
- - builder.finish(model, invalidIdentifier);
- - ByteBuffer modelBuffer = builder.dataBuffer();
- -
- - IllegalArgumentException exception =
- - assertThrows(IllegalArgumentException.class, () -> new MetadataExtractor(modelBuffer));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
- - + " flatbuffer.");
- - }
- -
- - @Test
- - public void metadataExtractor_loadMetadataWithInvalidIdentifier_throwsException() {
- - // Creates a model with metadata which contains an invalid identifier.
- - String invalidIdentifier = "INVI";
- - ByteBuffer metadata = createMetadataByteBuffer(invalidIdentifier, null);
- - ByteBuffer modelBuffer = createModelByteBuffer(metadata, DATA_TYPE);
- -
- - IllegalArgumentException exception =
- - assertThrows(IllegalArgumentException.class, () -> new MetadataExtractor(modelBuffer));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains(
- - "The identifier of the metadata is invalid. The buffer may not be a valid TFLite"
- - + " metadata flatbuffer.");
- - }
- -
- - @Test
- - public void getInputTensorCount_validModelFile() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - int count = metadataExtractor.getInputTensorCount();
- - assertThat(count).isEqualTo(3);
- - }
- -
- - @Test
- - public void getOutputTensorCount_validModelFile() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - int count = metadataExtractor.getOutputTensorCount();
- - assertThat(count).isEqualTo(3);
- - }
- -
- - @Test
- - public void getInputTensorShape_validTensorShape() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - int[] shape = metadataExtractor.getInputTensorShape(0);
- - assertArrayEquals(validShape, shape);
- - }
- -
- - @Test
- - public void getInputTensorShape_emptyTensor() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - int[] shape = metadataExtractor.getInputTensorShape(1);
- - assertThat(shape).isEmpty();
- - }
- -
- - @Test
- - public void getInputTensorType_emptyTensor() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - byte type = metadataExtractor.getInputTensorType(1);
- - assertThat(type).isEqualTo(TensorType.FLOAT32);
- - }
- -
- - @Test
- - public void getOutputTensorShape_validTensor() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - int[] shape = metadataExtractor.getOutputTensorShape(0);
- - assertArrayEquals(validShape, shape);
- - }
- -
- - @Test
- - public void getOutputTensorShape_emptyTensor() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - int[] shape = metadataExtractor.getOutputTensorShape(1);
- - assertThat(shape).isEmpty();
- - }
- -
- - @Test
- - public void getOutputTensorType_emptyTensor() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - byte type = metadataExtractor.getOutputTensorType(1);
- - assertThat(type).isEqualTo(TensorType.FLOAT32);
- - }
- -
- - @Test
- - public void getInputTensorShape_indexGreaterThanTensorNumber_throwsException()
- - throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> metadataExtractor.getInputTensorShape(3));
- - assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
- - }
- -
- - @Test
- - public void getInputTensorShape_negtiveIndex_throwsException() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> metadataExtractor.getInputTensorShape(-1));
- - assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
- - }
- -
- - @Test
- - public void getOutputTensorShape_indexGreaterThanTensorNumber_throwsException()
- - throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> metadataExtractor.getOutputTensorShape(3));
- - assertThat(exception).hasMessageThat().contains("The outputIndex specified is invalid.");
- - }
- -
- - @Test
- - public void getOutputTensorShape_negtiveIndex_throwsException() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> metadataExtractor.getOutputTensorShape(-1));
- - assertThat(exception).hasMessageThat().contains("The outputIndex specified is invalid.");
- - }
- -
- - @Test
- - public void getModelMetadata_modelWithMetadata() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - ModelMetadata modelMetadata = metadataExtractor.getModelMetadata();
- - assertThat(modelMetadata.name()).isEqualTo(MODEL_NAME);
- - }
- -
- - @Test
- - public void getModelMetadata_modelWithoutMetadata_throwsException() throws Exception {
- - // Creates a model flatbuffer without metadata.
- - ByteBuffer modelWithoutMetadata = createModelByteBuffer(/*metadataBuffer=*/ null, DATA_TYPE);
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
- -
- - IllegalStateException exception =
- - assertThrows(IllegalStateException.class, () -> metadataExtractor.getModelMetadata());
- - assertThat(exception)
- - .hasMessageThat()
- - .contains("This model does not contain model metadata.");
- - }
- -
- - @Test
- - public void metadataExtractor_modelWithEmptySubgraphMetadata_throwsException() {
- - // Creates a metadata FlatBuffer without empty subgraph metadata.
- - FlatBufferBuilder builder = new FlatBufferBuilder();
- - SubGraphMetadata.startSubGraphMetadata(builder);
- - int subgraph1Metadata = SubGraphMetadata.endSubGraphMetadata(builder);
- - int subgraphsMetadata =
- - ModelMetadata.createSubgraphMetadataVector(builder, new int[] {subgraph1Metadata});
- -
- - ModelMetadata.startModelMetadata(builder);
- - ModelMetadata.addSubgraphMetadata(builder, subgraphsMetadata);
- - int modelMetadata = ModelMetadata.endModelMetadata(builder);
- - builder.finish(modelMetadata, TFLITE_METADATA_IDENTIFIER);
- - ByteBuffer emptyMetadata = builder.dataBuffer();
- - ByteBuffer modelWithEmptyMetadata = createModelByteBuffer(emptyMetadata, DATA_TYPE);
- -
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> new MetadataExtractor(modelWithEmptyMetadata));
- - assertThat(exception)
- - .hasMessageThat()
- - .isEqualTo(
- - "The number of input tensors in the model is 3. The number of input tensors that"
- - + " recorded in the metadata is 0. These two values does not match.");
- - }
- -
- - @Test
- - public void metadataExtractor_modelWithEmptyMetadata_throwsException() {
- - // Creates a empty metadata FlatBuffer.
- - FlatBufferBuilder builder = new FlatBufferBuilder();
- - ModelMetadata.startModelMetadata(builder);
- - int modelMetadata = ModelMetadata.endModelMetadata(builder);
- - builder.finish(modelMetadata, TFLITE_METADATA_IDENTIFIER);
- -
- - ByteBuffer emptyMetadata = builder.dataBuffer();
- - ByteBuffer modelWithEmptyMetadata = createModelByteBuffer(emptyMetadata, DATA_TYPE);
- -
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> new MetadataExtractor(modelWithEmptyMetadata));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains("The metadata flatbuffer does not contain any subgraph metadata.");
- - }
- -
- - @Test
- - public void metadataExtractor_modelWithNoMetadata_throwsException() throws Exception {
- - // Creates a model flatbuffer without metadata.
- - ByteBuffer modelWithoutMetadata = createModelByteBuffer(/*metadataBuffer=*/ null, DATA_TYPE);
- -
- - // It is allowed to create a model without metadata, but invoking methods that reads metadata
- - // is not allowed.
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
- -
- - IllegalStateException exception =
- - assertThrows(
- - IllegalStateException.class, () -> metadataExtractor.getInputTensorMetadata(0));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains("This model does not contain model metadata.");
- - }
- -
- - @Test
- - public void metadataExtractor_modelWithIrrelevantMetadata_throwsException() throws Exception {
- - // Creates a model with irrelevant metadata.
- - FlatBufferBuilder builder = new FlatBufferBuilder();
- - SubGraph.startSubGraph(builder);
- - int subgraph = SubGraph.endSubGraph(builder);
- -
- - int metadataName = builder.createString("Irrelevant metadata");
- - Metadata.startMetadata(builder);
- - Metadata.addName(builder, metadataName);
- - int metadata = Metadata.endMetadata(builder);
- - int metadataArray = Model.createMetadataVector(builder, new int[] {metadata});
- -
- - // Creates Model.
- - int[] subgraphs = new int[1];
- - subgraphs[0] = subgraph;
- - int modelSubgraphs = Model.createSubgraphsVector(builder, subgraphs);
- - Model.startModel(builder);
- - Model.addSubgraphs(builder, modelSubgraphs);
- - Model.addMetadata(builder, metadataArray);
- - int model = Model.endModel(builder);
- - builder.finish(model, TFLITE_MODEL_IDENTIFIER);
- - ByteBuffer modelBuffer = builder.dataBuffer();
- -
- - // It is allowed to create a model without metadata, but invoking methods that reads metadata
- - // is not allowed.
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelBuffer);
- -
- - IllegalStateException exception =
- - assertThrows(
- - IllegalStateException.class, () -> metadataExtractor.getInputTensorMetadata(0));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains("This model does not contain model metadata.");
- - }
- -
- - @Test
- - public void getInputTensorMetadata_validTensor() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(0);
- - assertThat(inputMetadata.content().contentPropertiesType())
- - .isEqualTo(CONTENT_PROPERTIES_TYPE);
- - }
- -
- - @Test
- - public void getInputTensorMetadata_emptyTensor() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(1);
- - assertThat(inputMetadata.content()).isNull();
- - }
- -
- - @Test
- - public void getInputTensorMetadata_invalidTensor() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(2);
- - assertThat(inputMetadata.content().contentPropertiesType())
- - .isEqualTo(CONTENT_PROPERTIES_TYPE);
- - }
- -
- - @Test
- - public void getOutputTensorMetadata_validTensor() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(0);
- - assertThat(outputMetadata.content().contentPropertiesType())
- - .isEqualTo(CONTENT_PROPERTIES_TYPE);
- - }
- -
- - @Test
- - public void getOutputTensorMetadata_emptyTensor() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(1);
- - assertThat(outputMetadata.content()).isNull();
- - }
- -
- - @Test
- - public void getOutputTensorMetadata_invalidTensor() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(2);
- - assertThat(outputMetadata.content().contentPropertiesType())
- - .isEqualTo(CONTENT_PROPERTIES_TYPE);
- - }
- -
- - @Test
- - public void getInputTensorMetadata_indexGreaterThanTensorNumber_throwsException()
- - throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> metadataExtractor.getInputTensorMetadata(3));
- - assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
- - }
- -
- - @Test
- - public void getInputTensorMetadata_negtiveIndex_throwsException() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> metadataExtractor.getInputTensorMetadata(-1));
- - assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
- - }
- -
- - @Test
- - public void getOutputTensorMetadata_indexGreaterThanTensorNumber_throwsException()
- - throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> metadataExtractor.getOutputTensorMetadata(3));
- - assertThat(exception).hasMessageThat().contains("The outputIndex specified is invalid.");
- - }
- -
- - @Test
- - public void getOutputTensorMetadata_negtiveIndex_throwsException() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class, () -> metadataExtractor.getOutputTensorMetadata(-1));
- - assertThat(exception).hasMessageThat().contains("The outputIndex specified is invalid.");
- - }
- -
- - @Test
- - public void getInputTensorQuantizationParams_validScaleAndZeroPoint() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - QuantizationParams quantizationParams = metadataExtractor.getInputTensorQuantizationParams(0);
- - assertThat(quantizationParams.getScale()).isEqualTo(VALID_SCALE);
- - assertThat(quantizationParams.getZeroPoint()).isEqualTo(VALID_ZERO_POINT);
- - }
- -
- - @Test
- - public void getInputTensorQuantizationParams_emptyTensor() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - QuantizationParams quantizationParams = metadataExtractor.getInputTensorQuantizationParams(1);
- - // Scale and zero point are expected to be 1.0f and 0, respectively as default.
- - assertThat(quantizationParams.getScale()).isEqualTo(DEFAULT_SCALE);
- - assertThat(quantizationParams.getZeroPoint()).isEqualTo(DEFAULT_ZERO_POINT);
- - }
- -
- - @Test
- - public void getInputTensorQuantizationParams_invalidScale() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class,
- - () -> metadataExtractor.getInputTensorQuantizationParams(2));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains("Input and output tensors do not support per-channel quantization.");
- - }
- -
- - @Test
- - public void getOutputTensorQuantizationParams_validScaleAndZeroPoint() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - QuantizationParams quantizationParams =
- - metadataExtractor.getOutputTensorQuantizationParams(0);
- - assertThat(quantizationParams.getScale()).isEqualTo(VALID_SCALE);
- - assertThat(quantizationParams.getZeroPoint()).isEqualTo(VALID_ZERO_POINT);
- - }
- -
- - @Test
- - public void getOutputTensorQuantizationParams_emptyTensor() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - QuantizationParams quantizationParams =
- - metadataExtractor.getOutputTensorQuantizationParams(1);
- - // Scale and zero point are expected to be 1.0f and 0, respectively as default.
- - assertThat(quantizationParams.getScale()).isEqualTo(DEFAULT_SCALE);
- - assertThat(quantizationParams.getZeroPoint()).isEqualTo(DEFAULT_ZERO_POINT);
- - }
- -
- - @Test
- - public void getOutputTensorQuantizationParams_invalidScale() throws Exception {
- - // Creates a model flatbuffer with metadata.
- - ByteBuffer modelWithMetadata = createModelByteBuffer();
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class,
- - () -> metadataExtractor.getOutputTensorQuantizationParams(2));
- - assertThat(exception)
- - .hasMessageThat()
- - .contains("Input and output tensors do not support per-channel quantization.");
- - }
- -
- - @Test
- - public void isMinimumParserVersionSatisfied_olderVersion() throws Exception {
- - // A version older than the current one. The version starts from 1.0.0, thus 0.10.0 will
- - // precede any furture versions.
- - String minVersion = "0.10";
- - // Creates a metadata using the above version.
- - ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- - ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- -
- - assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
- - }
- -
- - @Test
- - public void isMinimumParserVersionSatisfied_sameVersionSamelength() throws Exception {
- - // A version the same as the current one.
- - String minVersion = MetadataParser.VERSION;
- - // Creates a metadata using the above version.
- - ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- - ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- -
- - assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
- - }
- -
- - @Test
- - public void isMinimumParserVersionSatisfied_sameVersionLongerlength() throws Exception {
- - // A version the same as the current one, but with longer length.
- - String minVersion = MetadataParser.VERSION + ".0";
- - // Creates a metadata using the above version.
- - ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- - ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- -
- - assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
- - }
- -
- - @Test
- - public void isMinimumParserVersionSatisfied_emptyVersion() throws Exception {
- - // An empty version, which can be generated before the first versioned release.
- - String minVersion = null;
- - // Creates a metadata using the above version.
- - ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- - ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- -
- - assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
- - }
- -
- - @Test
- - public void isMinimumParserVersionSatisfied_newerVersion() throws Exception {
- - // Creates a version newer than the current one by appending "1" to the end of the current
- - // version for testing purposes. For example, 1.0.0 becomes 1.0.01.
- - String minVersion = MetadataParser.VERSION + "1";
- - // Creates a metadata using the above version.
- - ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- - ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- -
- - assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isFalse();
- - }
- -
- - @Test
- - public void isMinimumParserVersionSatisfied_newerVersionLongerLength() throws Exception {
- - // Creates a version newer than the current one by appending ".1" to the end of the current
- - // version for testing purposes. For example, 1.0.0 becomes 1.0.0.1.
- - String minVersion = MetadataParser.VERSION + ".1";
- - // Creates a metadata using the above version.
- - ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- - ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
- -
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- -
- - assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isFalse();
- - }
- - }
- -
- - /** Parameterized tests for the input tensor data type. */
- - @RunWith(ParameterizedRobolectricTestRunner.class)
- - public static final class InputTensorType extends MetadataExtractorTest {
- - /** The tensor type that used to create the model buffer. */
- - @Parameter(0)
- - public byte tensorType;
- -
- - /** A list of TensorType that is used in the test. */
- - @Parameters
- - public static Collection<Object[]> data() {
- - return Arrays.asList(
- - new Object[][] {
- - {TensorType.FLOAT32}, {TensorType.INT32},
- - {TensorType.UINT8}, {TensorType.INT64},
- - {TensorType.STRING}
- - });
- - }
- -
- - @Test
- - public void getInputTensorType_validTensor() throws Exception {
- - ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, null);
- - ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, tensorType);
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - byte type = metadataExtractor.getInputTensorType(0);
- - assertThat(type).isEqualTo(tensorType);
- - }
- -
- - @Test
- - public void getOutputTensorType_validTensor() throws Exception {
- - ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, null);
- - ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, tensorType);
- - MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- - byte type = metadataExtractor.getOutputTensorType(0);
- - assertThat(type).isEqualTo(tensorType);
- - }
- - }
- -
- - /**
- - * Creates an example metadata flatbuffer, which contains one subgraph with three inputs and three
- - * outputs.
- - */
- - private static ByteBuffer createMetadataByteBuffer(
- - String identifier, @Nullable String minVersionStr) {
- - FlatBufferBuilder builder = new FlatBufferBuilder();
- -
- - Content.startContent(builder);
- - Content.addContentPropertiesType(builder, CONTENT_PROPERTIES_TYPE);
- - int content = Content.endContent(builder);
- -
- - TensorMetadata.startTensorMetadata(builder);
- - TensorMetadata.addContent(builder, content);
- - int metadataForValidTensor = TensorMetadata.endTensorMetadata(builder);
- -
- - TensorMetadata.startTensorMetadata(builder);
- - int metadataForEmptyTensor = TensorMetadata.endTensorMetadata(builder);
- -
- - TensorMetadata.startTensorMetadata(builder);
- - TensorMetadata.addContent(builder, content);
- - int metadataForInvalidTensor = TensorMetadata.endTensorMetadata(builder);
- -
- - int[] tensorMetadataArray =
- - new int[] {metadataForValidTensor, metadataForEmptyTensor, metadataForInvalidTensor};
- - int inputTensorMetadata =
- - SubGraphMetadata.createInputTensorMetadataVector(builder, tensorMetadataArray);
- - int outputTensorMetadata =
- - SubGraphMetadata.createOutputTensorMetadataVector(builder, tensorMetadataArray);
- -
- - SubGraphMetadata.startSubGraphMetadata(builder);
- - SubGraphMetadata.addInputTensorMetadata(builder, inputTensorMetadata);
- - SubGraphMetadata.addOutputTensorMetadata(builder, outputTensorMetadata);
- - int subgraph1Metadata = SubGraphMetadata.endSubGraphMetadata(builder);
- -
- - int[] subgraphMetadataArray = new int[] {subgraph1Metadata};
- - int subgraphsMetadata =
- - ModelMetadata.createSubgraphMetadataVector(builder, subgraphMetadataArray);
- -
- - int modelName = builder.createString(MODEL_NAME);
- - if (minVersionStr != null) {
- - int minVersion = builder.createString(minVersionStr);
- - ModelMetadata.startModelMetadata(builder);
- - ModelMetadata.addMinParserVersion(builder, minVersion);
- - } else {
- - // If minVersionStr is null, skip generating the field in the metadata.
- - ModelMetadata.startModelMetadata(builder);
- - }
- - ModelMetadata.addName(builder, modelName);
- - ModelMetadata.addSubgraphMetadata(builder, subgraphsMetadata);
- - int modelMetadata = ModelMetadata.endModelMetadata(builder);
- -
- - builder.finish(modelMetadata, identifier);
- - return builder.dataBuffer();
- - }
- -
- - private static int createQuantizationParameters(
- - FlatBufferBuilder builder, float[] scale, long[] zeroPoint) {
- - int inputScale = QuantizationParameters.createScaleVector(builder, scale);
- - int inputZeroPoint = QuantizationParameters.createZeroPointVector(builder, zeroPoint);
- - QuantizationParameters.startQuantizationParameters(builder);
- - QuantizationParameters.addScale(builder, inputScale);
- - QuantizationParameters.addZeroPoint(builder, inputZeroPoint);
- - return QuantizationParameters.endQuantizationParameters(builder);
- - }
- -
- - private static int createTensor(
- - FlatBufferBuilder builder, int[] inputShape, byte inputType, int inputQuantization) {
- - int inputShapeVector1 = Tensor.createShapeVector(builder, inputShape);
- - Tensor.startTensor(builder);
- - Tensor.addShape(builder, inputShapeVector1);
- - Tensor.addType(builder, inputType);
- - Tensor.addQuantization(builder, inputQuantization);
- - return Tensor.endTensor(builder);
- - }
- -
- - /**
- - * Creates an example model flatbuffer, which contains one subgraph with three inputs and three
- - * output.
- - */
- - private static ByteBuffer createModelByteBuffer(ByteBuffer metadataBuffer, byte dataType) {
- - FlatBufferBuilder builder = new FlatBufferBuilder();
- -
- - // Creates a valid set of quantization parameters.
- - int validQuantization =
- - createQuantizationParameters(
- - builder, new float[] {VALID_SCALE}, new long[] {VALID_ZERO_POINT});
- -
- - // Creates an invalid set of quantization parameters.
- - int inValidQuantization = createQuantizationParameters(builder, invalidScale, invalidZeroPoint);
- -
- - // Creates an input Tensor with valid quantization parameters.
- - int validTensor = createTensor(builder, validShape, dataType, validQuantization);
- -
- - // Creates an empty input Tensor.
- - Tensor.startTensor(builder);
- - int emptyTensor = Tensor.endTensor(builder);
- -
- - // Creates an input Tensor with invalid quantization parameters.
- - int invalidTensor = createTensor(builder, validShape, dataType, inValidQuantization);
- -
- - // Creates the SubGraph.
- - int[] tensors = new int[6];
- - tensors[0] = validTensor;
- - tensors[1] = emptyTensor;
- - tensors[2] = invalidTensor;
- - tensors[3] = validTensor;
- - tensors[4] = emptyTensor;
- - tensors[5] = invalidTensor;
- - int subgraphTensors = SubGraph.createTensorsVector(builder, tensors);
- -
- - int subgraphInputs = SubGraph.createInputsVector(builder, new int[] {0, 1, 2});
- - int subgraphOutputs = SubGraph.createOutputsVector(builder, new int[] {3, 4, 5});
- -
- - SubGraph.startSubGraph(builder);
- - SubGraph.addTensors(builder, subgraphTensors);
- - SubGraph.addInputs(builder, subgraphInputs);
- - SubGraph.addOutputs(builder, subgraphOutputs);
- - int subgraph = SubGraph.endSubGraph(builder);
- -
- - // Creates the Model.
- - int[] subgraphs = new int[1];
- - subgraphs[0] = subgraph;
- - int modelSubgraphs = Model.createSubgraphsVector(builder, subgraphs);
- -
- - // Inserts metadataBuffer into the model if it's not null.
- - int modelBuffers = EMPTY_FLATBUFFER_VECTOR;
- - int metadataArray = EMPTY_FLATBUFFER_VECTOR;
- - if (metadataBuffer != null) {
- - int data = Buffer.createDataVector(builder, metadataBuffer);
- - Buffer.startBuffer(builder);
- - Buffer.addData(builder, data);
- - int buffer = Buffer.endBuffer(builder);
- - modelBuffers = Model.createBuffersVector(builder, new int[] {buffer});
- -
- - int metadataName = builder.createString(ModelInfo.METADATA_FIELD_NAME);
- - Metadata.startMetadata(builder);
- - Metadata.addName(builder, metadataName);
- - Metadata.addBuffer(builder, 0);
- - int metadata = Metadata.endMetadata(builder);
- - metadataArray = Model.createMetadataVector(builder, new int[] {metadata});
- - }
- -
- - Model.startModel(builder);
- - Model.addSubgraphs(builder, modelSubgraphs);
- - if (modelBuffers != EMPTY_FLATBUFFER_VECTOR && metadataArray != EMPTY_FLATBUFFER_VECTOR) {
- - Model.addBuffers(builder, modelBuffers);
- - Model.addMetadata(builder, metadataArray);
- + private static final int[] validShape = new int[] {4, 10, 10, 3};
- + private static final byte DATA_TYPE = TensorType.UINT8;
- + private static final byte CONTENT_PROPERTIES_TYPE = ContentProperties.ImageProperties;
- + private static final float VALID_SCALE = 3.3f;
- + private static final long VALID_ZERO_POINT = 2;
- + private static final float DEFAULT_SCALE = 0.0f;
- + private static final long DEFAULT_ZERO_POINT = 0;
- + private static final String MODEL_NAME = "model.tflite";
- + // Scale and zero point should both be a single value, not an array.
- + private static final float[] invalidScale = new float[] {0.0f, 1.2f};
- + private static final long[] invalidZeroPoint = new long[] {1, 2};
- + private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite";
- + // labels.txt is packed in mobilenet_v1_1.0_224_quant.tflite as an associated file.
- + private static final String VALID_LABEL_FILE_NAME = "labels.txt";
- + // invalid.txt is not packed in mobilenet_v1_1.0_224_quant.tflite.
- + private static final String INVALID_LABEL_FILE_NAME = "invalid.txt";
- + private static final int EMPTY_FLATBUFFER_VECTOR = -1;
- + private static final String TFLITE_MODEL_IDENTIFIER = "TFL3";
- + private static final String TFLITE_METADATA_IDENTIFIER = "M001";
- +
- + /** General tests of MetadataExtractor. */
- + @RunWith(RobolectricTestRunner.class)
- + public static final class General extends MetadataExtractorTest {
- + @Test
- + public void hasMetadata_modelWithMetadata() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + assertThat(metadataExtractor.hasMetadata()).isTrue();
- + }
- +
- + @Test
- + public void hasMetadata_modelWithoutMetadata() throws Exception {
- + // Creates a model flatbuffer without metadata.
- + ByteBuffer modelWithoutMetadata =
- + createModelByteBuffer(/*metadataBuffer=*/null, DATA_TYPE);
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
- + assertThat(metadataExtractor.hasMetadata()).isFalse();
- + }
- +
- + @Ignore
- + @Test
- + public void getAssociatedFile_validAssociateFile() throws Exception {
- + ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
- + MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
- + InputStream associateFileStream =
- + mobileNetMetadataExtractor.getAssociatedFile(VALID_LABEL_FILE_NAME);
- +
- + // Reads the golden file from context.
- + Context context = ApplicationProvider.getApplicationContext();
- + InputStream goldenAssociateFileStream = context.getAssets().open(VALID_LABEL_FILE_NAME);
- + assertThat(IOUtils.contentEquals(goldenAssociateFileStream, associateFileStream))
- + .isTrue();
- + }
- +
- + @Ignore
- + @Test
- + public void getAssociatedFile_invalidAssociateFile() throws Exception {
- + ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
- + MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> mobileNetMetadataExtractor.getAssociatedFile(INVALID_LABEL_FILE_NAME));
- + assertThat(exception).hasMessageThat().isEqualTo(String.format(
- + "The file, %s, does not exist in the zip file.", INVALID_LABEL_FILE_NAME));
- + }
- +
- + @Ignore
- + @Test
- + public void getAssociatedFile_nullFileName() throws Exception {
- + ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
- + MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> mobileNetMetadataExtractor.getAssociatedFile(/*fileName=*/null));
- + assertThat(exception).hasMessageThat().contains(
- + "The file, null, does not exist in the zip file.");
- + }
- +
- + @Test
- + public void getAssociatedFile_nonZipModel_throwsException() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + IllegalStateException exception = assertThrows(IllegalStateException.class,
- + () -> metadataExtractor.getAssociatedFile(VALID_LABEL_FILE_NAME));
- + assertThat(exception).hasMessageThat().contains(
- + "This model does not contain associated files, and is not a Zip file.");
- + }
- +
- + @Test
- + public void getAssociatedFileNames_nonZipModel_throwsException() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + IllegalStateException exception = assertThrows(
- + IllegalStateException.class, metadataExtractor::getAssociatedFileNames);
- + assertThat(exception).hasMessageThat().contains(
- + "This model does not contain associated files, and is not a Zip file.");
- + }
- +
- + @Ignore
- + @Test
- + public void getAssociatedFileNames_validFileNames() throws Exception {
- + ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
- + MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
- + Set<String> expectedSet = new HashSet<>();
- + expectedSet.add(VALID_LABEL_FILE_NAME);
- + assertThat(mobileNetMetadataExtractor.getAssociatedFileNames()).isEqualTo(expectedSet);
- + }
- +
- + @Test
- + public void metadataExtractor_loadNullBuffer_throwsException() {
- + ByteBuffer nullBuffer = null;
- + NullPointerException exception = assertThrows(
- + NullPointerException.class, () -> new MetadataExtractor(nullBuffer));
- + assertThat(exception).hasMessageThat().contains("Model flatbuffer cannot be null.");
- + }
- +
- + @Test
- + public void metadataExtractor_loadRandomBuffer_throwsException() {
- + ByteBuffer randomBuffer = createRandomByteBuffer();
- + IllegalArgumentException exception = assertThrows(
- + IllegalArgumentException.class, () -> new MetadataExtractor(randomBuffer));
- + assertThat(exception).hasMessageThat().contains(
- + "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
- + + " flatbuffer.");
- + }
- +
- + @Test
- + public void metadataExtractor_loadModelWithInvalidIdentifier_throwsException() {
- + // Creates a model with an invalid identifier.
- + String invalidIdentifier = "INVI";
- + FlatBufferBuilder builder = new FlatBufferBuilder();
- + Model.startModel(builder);
- + int model = Model.endModel(builder);
- + builder.finish(model, invalidIdentifier);
- + ByteBuffer modelBuffer = builder.dataBuffer();
- +
- + IllegalArgumentException exception = assertThrows(
- + IllegalArgumentException.class, () -> new MetadataExtractor(modelBuffer));
- + assertThat(exception).hasMessageThat().contains(
- + "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
- + + " flatbuffer.");
- + }
- +
- + @Test
- + public void metadataExtractor_loadMetadataWithInvalidIdentifier_throwsException() {
- + // Creates a model with metadata which contains an invalid identifier.
- + String invalidIdentifier = "INVI";
- + ByteBuffer metadata = createMetadataByteBuffer(invalidIdentifier, null);
- + ByteBuffer modelBuffer = createModelByteBuffer(metadata, DATA_TYPE);
- +
- + IllegalArgumentException exception = assertThrows(
- + IllegalArgumentException.class, () -> new MetadataExtractor(modelBuffer));
- + assertThat(exception).hasMessageThat().contains(
- + "The identifier of the metadata is invalid. The buffer may not be a valid TFLite"
- + + " metadata flatbuffer.");
- + }
- +
- + @Test
- + public void getInputTensorCount_validModelFile() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + int count = metadataExtractor.getInputTensorCount();
- + assertThat(count).isEqualTo(3);
- + }
- +
- + @Test
- + public void getOutputTensorCount_validModelFile() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + int count = metadataExtractor.getOutputTensorCount();
- + assertThat(count).isEqualTo(3);
- + }
- +
- + @Test
- + public void getInputTensorShape_validTensorShape() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + int[] shape = metadataExtractor.getInputTensorShape(0);
- + assertArrayEquals(validShape, shape);
- + }
- +
- + @Test
- + public void getInputTensorShape_emptyTensor() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + int[] shape = metadataExtractor.getInputTensorShape(1);
- + assertThat(shape).isEmpty();
- + }
- +
- + @Test
- + public void getInputTensorType_emptyTensor() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + byte type = metadataExtractor.getInputTensorType(1);
- + assertThat(type).isEqualTo(TensorType.FLOAT32);
- + }
- +
- + @Test
- + public void getOutputTensorShape_validTensor() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + int[] shape = metadataExtractor.getOutputTensorShape(0);
- + assertArrayEquals(validShape, shape);
- + }
- +
- + @Test
- + public void getOutputTensorShape_emptyTensor() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + int[] shape = metadataExtractor.getOutputTensorShape(1);
- + assertThat(shape).isEmpty();
- + }
- +
- + @Test
- + public void getOutputTensorType_emptyTensor() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + byte type = metadataExtractor.getOutputTensorType(1);
- + assertThat(type).isEqualTo(TensorType.FLOAT32);
- + }
- +
- + @Test
- + public void getInputTensorShape_indexGreaterThanTensorNumber_throwsException()
- + throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + IllegalArgumentException exception = assertThrows(
- + IllegalArgumentException.class, () -> metadataExtractor.getInputTensorShape(3));
- + assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
- + }
- +
- + @Test
- + public void getInputTensorShape_negtiveIndex_throwsException() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> metadataExtractor.getInputTensorShape(-1));
- + assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
- + }
- +
- + @Test
- + public void getOutputTensorShape_indexGreaterThanTensorNumber_throwsException()
- + throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> metadataExtractor.getOutputTensorShape(3));
- + assertThat(exception).hasMessageThat().contains(
- + "The outputIndex specified is invalid.");
- + }
- +
- + @Test
- + public void getOutputTensorShape_negtiveIndex_throwsException() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> metadataExtractor.getOutputTensorShape(-1));
- + assertThat(exception).hasMessageThat().contains(
- + "The outputIndex specified is invalid.");
- + }
- +
- + @Test
- + public void getModelMetadata_modelWithMetadata() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + ModelMetadata modelMetadata = metadataExtractor.getModelMetadata();
- + assertThat(modelMetadata.name()).isEqualTo(MODEL_NAME);
- + }
- +
- + @Test
- + public void getModelMetadata_modelWithoutMetadata_throwsException() throws Exception {
- + // Creates a model flatbuffer without metadata.
- + ByteBuffer modelWithoutMetadata =
- + createModelByteBuffer(/*metadataBuffer=*/null, DATA_TYPE);
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
- +
- + IllegalStateException exception = assertThrows(
- + IllegalStateException.class, () -> metadataExtractor.getModelMetadata());
- + assertThat(exception).hasMessageThat().contains(
- + "This model does not contain model metadata.");
- + }
- +
- + @Test
- + public void metadataExtractor_modelWithEmptySubgraphMetadata_throwsException() {
- + // Creates a metadata FlatBuffer without empty subgraph metadata.
- + FlatBufferBuilder builder = new FlatBufferBuilder();
- + SubGraphMetadata.startSubGraphMetadata(builder);
- + int subgraph1Metadata = SubGraphMetadata.endSubGraphMetadata(builder);
- + int subgraphsMetadata = ModelMetadata.createSubgraphMetadataVector(
- + builder, new int[] {subgraph1Metadata});
- +
- + ModelMetadata.startModelMetadata(builder);
- + ModelMetadata.addSubgraphMetadata(builder, subgraphsMetadata);
- + int modelMetadata = ModelMetadata.endModelMetadata(builder);
- + builder.finish(modelMetadata, TFLITE_METADATA_IDENTIFIER);
- + ByteBuffer emptyMetadata = builder.dataBuffer();
- + ByteBuffer modelWithEmptyMetadata = createModelByteBuffer(emptyMetadata, DATA_TYPE);
- +
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> new MetadataExtractor(modelWithEmptyMetadata));
- + assertThat(exception).hasMessageThat().isEqualTo(
- + "The number of input tensors in the model is 3. The number of input tensors that"
- + + " recorded in the metadata is 0. These two values does not match.");
- + }
- +
- + @Test
- + public void metadataExtractor_modelWithEmptyMetadata_throwsException() {
- + // Creates a empty metadata FlatBuffer.
- + FlatBufferBuilder builder = new FlatBufferBuilder();
- + ModelMetadata.startModelMetadata(builder);
- + int modelMetadata = ModelMetadata.endModelMetadata(builder);
- + builder.finish(modelMetadata, TFLITE_METADATA_IDENTIFIER);
- +
- + ByteBuffer emptyMetadata = builder.dataBuffer();
- + ByteBuffer modelWithEmptyMetadata = createModelByteBuffer(emptyMetadata, DATA_TYPE);
- +
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> new MetadataExtractor(modelWithEmptyMetadata));
- + assertThat(exception).hasMessageThat().contains(
- + "The metadata flatbuffer does not contain any subgraph metadata.");
- + }
- +
- + @Test
- + public void metadataExtractor_modelWithNoMetadata_throwsException() throws Exception {
- + // Creates a model flatbuffer without metadata.
- + ByteBuffer modelWithoutMetadata =
- + createModelByteBuffer(/*metadataBuffer=*/null, DATA_TYPE);
- +
- + // It is allowed to create a model without metadata, but invoking methods that reads
- + // metadata is not allowed.
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
- +
- + IllegalStateException exception = assertThrows(
- + IllegalStateException.class, () -> metadataExtractor.getInputTensorMetadata(0));
- + assertThat(exception).hasMessageThat().contains(
- + "This model does not contain model metadata.");
- + }
- +
- + @Test
- + public void metadataExtractor_modelWithIrrelevantMetadata_throwsException()
- + throws Exception {
- + // Creates a model with irrelevant metadata.
- + FlatBufferBuilder builder = new FlatBufferBuilder();
- + SubGraph.startSubGraph(builder);
- + int subgraph = SubGraph.endSubGraph(builder);
- +
- + int metadataName = builder.createString("Irrelevant metadata");
- + Metadata.startMetadata(builder);
- + Metadata.addName(builder, metadataName);
- + int metadata = Metadata.endMetadata(builder);
- + int metadataArray = Model.createMetadataVector(builder, new int[] {metadata});
- +
- + // Creates Model.
- + int[] subgraphs = new int[1];
- + subgraphs[0] = subgraph;
- + int modelSubgraphs = Model.createSubgraphsVector(builder, subgraphs);
- + Model.startModel(builder);
- + Model.addSubgraphs(builder, modelSubgraphs);
- + Model.addMetadata(builder, metadataArray);
- + int model = Model.endModel(builder);
- + builder.finish(model, TFLITE_MODEL_IDENTIFIER);
- + ByteBuffer modelBuffer = builder.dataBuffer();
- +
- + // It is allowed to create a model without metadata, but invoking methods that reads
- + // metadata is not allowed.
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelBuffer);
- +
- + IllegalStateException exception = assertThrows(
- + IllegalStateException.class, () -> metadataExtractor.getInputTensorMetadata(0));
- + assertThat(exception).hasMessageThat().contains(
- + "This model does not contain model metadata.");
- + }
- +
- + @Test
- + public void getInputTensorMetadata_validTensor() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(0);
- + assertThat(inputMetadata.content().contentPropertiesType())
- + .isEqualTo(CONTENT_PROPERTIES_TYPE);
- + }
- +
- + @Test
- + public void getInputTensorMetadata_emptyTensor() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(1);
- + assertThat(inputMetadata.content()).isNull();
- + }
- +
- + @Test
- + public void getInputTensorMetadata_invalidTensor() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(2);
- + assertThat(inputMetadata.content().contentPropertiesType())
- + .isEqualTo(CONTENT_PROPERTIES_TYPE);
- + }
- +
- + @Test
- + public void getOutputTensorMetadata_validTensor() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(0);
- + assertThat(outputMetadata.content().contentPropertiesType())
- + .isEqualTo(CONTENT_PROPERTIES_TYPE);
- + }
- +
- + @Test
- + public void getOutputTensorMetadata_emptyTensor() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(1);
- + assertThat(outputMetadata.content()).isNull();
- + }
- +
- + @Test
- + public void getOutputTensorMetadata_invalidTensor() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(2);
- + assertThat(outputMetadata.content().contentPropertiesType())
- + .isEqualTo(CONTENT_PROPERTIES_TYPE);
- + }
- +
- + @Test
- + public void getInputTensorMetadata_indexGreaterThanTensorNumber_throwsException()
- + throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> metadataExtractor.getInputTensorMetadata(3));
- + assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
- + }
- +
- + @Test
- + public void getInputTensorMetadata_negtiveIndex_throwsException() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> metadataExtractor.getInputTensorMetadata(-1));
- + assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
- + }
- +
- + @Test
- + public void getOutputTensorMetadata_indexGreaterThanTensorNumber_throwsException()
- + throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> metadataExtractor.getOutputTensorMetadata(3));
- + assertThat(exception).hasMessageThat().contains(
- + "The outputIndex specified is invalid.");
- + }
- +
- + @Test
- + public void getOutputTensorMetadata_negtiveIndex_throwsException() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> metadataExtractor.getOutputTensorMetadata(-1));
- + assertThat(exception).hasMessageThat().contains(
- + "The outputIndex specified is invalid.");
- + }
- +
- + @Test
- + public void getInputTensorQuantizationParams_validScaleAndZeroPoint() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + QuantizationParams quantizationParams =
- + metadataExtractor.getInputTensorQuantizationParams(0);
- + assertThat(quantizationParams.getScale()).isEqualTo(VALID_SCALE);
- + assertThat(quantizationParams.getZeroPoint()).isEqualTo(VALID_ZERO_POINT);
- + }
- +
- + @Test
- + public void getInputTensorQuantizationParams_emptyTensor() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + QuantizationParams quantizationParams =
- + metadataExtractor.getInputTensorQuantizationParams(1);
- + // Scale and zero point are expected to be 1.0f and 0, respectively as default.
- + assertThat(quantizationParams.getScale()).isEqualTo(DEFAULT_SCALE);
- + assertThat(quantizationParams.getZeroPoint()).isEqualTo(DEFAULT_ZERO_POINT);
- + }
- +
- + @Test
- + public void getInputTensorQuantizationParams_invalidScale() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> metadataExtractor.getInputTensorQuantizationParams(2));
- + assertThat(exception).hasMessageThat().contains(
- + "Input and output tensors do not support per-channel quantization.");
- + }
- +
- + @Test
- + public void getOutputTensorQuantizationParams_validScaleAndZeroPoint() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + QuantizationParams quantizationParams =
- + metadataExtractor.getOutputTensorQuantizationParams(0);
- + assertThat(quantizationParams.getScale()).isEqualTo(VALID_SCALE);
- + assertThat(quantizationParams.getZeroPoint()).isEqualTo(VALID_ZERO_POINT);
- + }
- +
- + @Test
- + public void getOutputTensorQuantizationParams_emptyTensor() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + QuantizationParams quantizationParams =
- + metadataExtractor.getOutputTensorQuantizationParams(1);
- + // Scale and zero point are expected to be 1.0f and 0, respectively as default.
- + assertThat(quantizationParams.getScale()).isEqualTo(DEFAULT_SCALE);
- + assertThat(quantizationParams.getZeroPoint()).isEqualTo(DEFAULT_ZERO_POINT);
- + }
- +
- + @Test
- + public void getOutputTensorQuantizationParams_invalidScale() throws Exception {
- + // Creates a model flatbuffer with metadata.
- + ByteBuffer modelWithMetadata = createModelByteBuffer();
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> metadataExtractor.getOutputTensorQuantizationParams(2));
- + assertThat(exception).hasMessageThat().contains(
- + "Input and output tensors do not support per-channel quantization.");
- + }
- +
- + @Test
- + public void isMinimumParserVersionSatisfied_olderVersion() throws Exception {
- + // A version older than the current one. The version starts from 1.0.0, thus 0.10.0 will
- + // precede any furture versions.
- + String minVersion = "0.10";
- + // Creates a metadata using the above version.
- + ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- + ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- +
- + assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
- + }
- +
- + @Test
- + public void isMinimumParserVersionSatisfied_sameVersionSamelength() throws Exception {
- + // A version the same as the current one.
- + String minVersion = MetadataParser.VERSION;
- + // Creates a metadata using the above version.
- + ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- + ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- +
- + assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
- + }
- +
- + @Test
- + public void isMinimumParserVersionSatisfied_sameVersionLongerlength() throws Exception {
- + // A version the same as the current one, but with longer length.
- + String minVersion = MetadataParser.VERSION + ".0";
- + // Creates a metadata using the above version.
- + ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- + ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- +
- + assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
- + }
- +
- + @Test
- + public void isMinimumParserVersionSatisfied_emptyVersion() throws Exception {
- + // An empty version, which can be generated before the first versioned release.
- + String minVersion = null;
- + // Creates a metadata using the above version.
- + ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- + ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- +
- + assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
- + }
- +
- + @Test
- + public void isMinimumParserVersionSatisfied_newerVersion() throws Exception {
- + // Creates a version newer than the current one by appending "1" to the end of the
- + // current version for testing purposes. For example, 1.0.0 becomes 1.0.01.
- + String minVersion = MetadataParser.VERSION + "1";
- + // Creates a metadata using the above version.
- + ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- + ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- +
- + assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isFalse();
- + }
- +
- + @Test
- + public void isMinimumParserVersionSatisfied_newerVersionLongerLength() throws Exception {
- + // Creates a version newer than the current one by appending ".1" to the end of the
- + // current version for testing purposes. For example, 1.0.0 becomes 1.0.0.1.
- + String minVersion = MetadataParser.VERSION + ".1";
- + // Creates a metadata using the above version.
- + ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- + ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
- +
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- +
- + assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isFalse();
- + }
- + }
- +
- + /** Parameterized tests for the input tensor data type. */
- + @RunWith(ParameterizedRobolectricTestRunner.class)
- + public static final class InputTensorType extends MetadataExtractorTest {
- + /** The tensor type that used to create the model buffer. */
- + @Parameter(0)
- + public byte tensorType;
- +
- + /** A list of TensorType that is used in the test. */
- + @Parameters
- + public static Collection<Object[]> data() {
- + return Arrays.asList(new Object[][] {{TensorType.FLOAT32}, {TensorType.INT32},
- + {TensorType.UINT8}, {TensorType.INT64}, {TensorType.STRING}});
- + }
- +
- + @Test
- + public void getInputTensorType_validTensor() throws Exception {
- + ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, null);
- + ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, tensorType);
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + byte type = metadataExtractor.getInputTensorType(0);
- + assertThat(type).isEqualTo(tensorType);
- + }
- +
- + @Test
- + public void getOutputTensorType_validTensor() throws Exception {
- + ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, null);
- + ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, tensorType);
- + MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- + byte type = metadataExtractor.getOutputTensorType(0);
- + assertThat(type).isEqualTo(tensorType);
- + }
- + }
- +
- + /**
- + * Creates an example metadata flatbuffer, which contains one subgraph with three inputs and
- + * three outputs.
- + */
- + private static ByteBuffer createMetadataByteBuffer(
- + String identifier, @Nullable String minVersionStr) {
- + FlatBufferBuilder builder = new FlatBufferBuilder();
- +
- + Content.startContent(builder);
- + Content.addContentPropertiesType(builder, CONTENT_PROPERTIES_TYPE);
- + int content = Content.endContent(builder);
- +
- + TensorMetadata.startTensorMetadata(builder);
- + TensorMetadata.addContent(builder, content);
- + int metadataForValidTensor = TensorMetadata.endTensorMetadata(builder);
- +
- + TensorMetadata.startTensorMetadata(builder);
- + int metadataForEmptyTensor = TensorMetadata.endTensorMetadata(builder);
- +
- + TensorMetadata.startTensorMetadata(builder);
- + TensorMetadata.addContent(builder, content);
- + int metadataForInvalidTensor = TensorMetadata.endTensorMetadata(builder);
- +
- + int[] tensorMetadataArray = new int[] {
- + metadataForValidTensor, metadataForEmptyTensor, metadataForInvalidTensor};
- + int inputTensorMetadata =
- + SubGraphMetadata.createInputTensorMetadataVector(builder, tensorMetadataArray);
- + int outputTensorMetadata =
- + SubGraphMetadata.createOutputTensorMetadataVector(builder, tensorMetadataArray);
- +
- + SubGraphMetadata.startSubGraphMetadata(builder);
- + SubGraphMetadata.addInputTensorMetadata(builder, inputTensorMetadata);
- + SubGraphMetadata.addOutputTensorMetadata(builder, outputTensorMetadata);
- + int subgraph1Metadata = SubGraphMetadata.endSubGraphMetadata(builder);
- +
- + int[] subgraphMetadataArray = new int[] {subgraph1Metadata};
- + int subgraphsMetadata =
- + ModelMetadata.createSubgraphMetadataVector(builder, subgraphMetadataArray);
- +
- + int modelName = builder.createString(MODEL_NAME);
- + if (minVersionStr != null) {
- + int minVersion = builder.createString(minVersionStr);
- + ModelMetadata.startModelMetadata(builder);
- + ModelMetadata.addMinParserVersion(builder, minVersion);
- + } else {
- + // If minVersionStr is null, skip generating the field in the metadata.
- + ModelMetadata.startModelMetadata(builder);
- + }
- + ModelMetadata.addName(builder, modelName);
- + ModelMetadata.addSubgraphMetadata(builder, subgraphsMetadata);
- + int modelMetadata = ModelMetadata.endModelMetadata(builder);
- +
- + builder.finish(modelMetadata, identifier);
- + return builder.dataBuffer();
- + }
- +
- + private static int createQuantizationParameters(
- + FlatBufferBuilder builder, float[] scale, long[] zeroPoint) {
- + int inputScale = QuantizationParameters.createScaleVector(builder, scale);
- + int inputZeroPoint = QuantizationParameters.createZeroPointVector(builder, zeroPoint);
- + QuantizationParameters.startQuantizationParameters(builder);
- + QuantizationParameters.addScale(builder, inputScale);
- + QuantizationParameters.addZeroPoint(builder, inputZeroPoint);
- + return QuantizationParameters.endQuantizationParameters(builder);
- + }
- +
- + private static int createTensor(
- + FlatBufferBuilder builder, int[] inputShape, byte inputType, int inputQuantization) {
- + int inputShapeVector1 = Tensor.createShapeVector(builder, inputShape);
- + Tensor.startTensor(builder);
- + Tensor.addShape(builder, inputShapeVector1);
- + Tensor.addType(builder, inputType);
- + Tensor.addQuantization(builder, inputQuantization);
- + return Tensor.endTensor(builder);
- + }
- +
- + /**
- + * Creates an example model flatbuffer, which contains one subgraph with three inputs and three
- + * output.
- + */
- + private static ByteBuffer createModelByteBuffer(ByteBuffer metadataBuffer, byte dataType) {
- + FlatBufferBuilder builder = new FlatBufferBuilder();
- +
- + // Creates a valid set of quantization parameters.
- + int validQuantization = createQuantizationParameters(
- + builder, new float[] {VALID_SCALE}, new long[] {VALID_ZERO_POINT});
- +
- + // Creates an invalid set of quantization parameters.
- + int inValidQuantization =
- + createQuantizationParameters(builder, invalidScale, invalidZeroPoint);
- +
- + // Creates an input Tensor with valid quantization parameters.
- + int validTensor = createTensor(builder, validShape, dataType, validQuantization);
- +
- + // Creates an empty input Tensor.
- + Tensor.startTensor(builder);
- + int emptyTensor = Tensor.endTensor(builder);
- +
- + // Creates an input Tensor with invalid quantization parameters.
- + int invalidTensor = createTensor(builder, validShape, dataType, inValidQuantization);
- +
- + // Creates the SubGraph.
- + int[] tensors = new int[6];
- + tensors[0] = validTensor;
- + tensors[1] = emptyTensor;
- + tensors[2] = invalidTensor;
- + tensors[3] = validTensor;
- + tensors[4] = emptyTensor;
- + tensors[5] = invalidTensor;
- + int subgraphTensors = SubGraph.createTensorsVector(builder, tensors);
- +
- + int subgraphInputs = SubGraph.createInputsVector(builder, new int[] {0, 1, 2});
- + int subgraphOutputs = SubGraph.createOutputsVector(builder, new int[] {3, 4, 5});
- +
- + SubGraph.startSubGraph(builder);
- + SubGraph.addTensors(builder, subgraphTensors);
- + SubGraph.addInputs(builder, subgraphInputs);
- + SubGraph.addOutputs(builder, subgraphOutputs);
- + int subgraph = SubGraph.endSubGraph(builder);
- +
- + // Creates the Model.
- + int[] subgraphs = new int[1];
- + subgraphs[0] = subgraph;
- + int modelSubgraphs = Model.createSubgraphsVector(builder, subgraphs);
- +
- + // Inserts metadataBuffer into the model if it's not null.
- + int modelBuffers = EMPTY_FLATBUFFER_VECTOR;
- + int metadataArray = EMPTY_FLATBUFFER_VECTOR;
- + if (metadataBuffer != null) {
- + int data = Buffer.createDataVector(builder, metadataBuffer);
- + Buffer.startBuffer(builder);
- + Buffer.addData(builder, data);
- + int buffer = Buffer.endBuffer(builder);
- + modelBuffers = Model.createBuffersVector(builder, new int[] {buffer});
- +
- + int metadataName = builder.createString(ModelInfo.METADATA_FIELD_NAME);
- + Metadata.startMetadata(builder);
- + Metadata.addName(builder, metadataName);
- + Metadata.addBuffer(builder, 0);
- + int metadata = Metadata.endMetadata(builder);
- + metadataArray = Model.createMetadataVector(builder, new int[] {metadata});
- + }
- +
- + Model.startModel(builder);
- + Model.addSubgraphs(builder, modelSubgraphs);
- + if (modelBuffers != EMPTY_FLATBUFFER_VECTOR && metadataArray != EMPTY_FLATBUFFER_VECTOR) {
- + Model.addBuffers(builder, modelBuffers);
- + Model.addMetadata(builder, metadataArray);
- + }
- + int model = Model.endModel(builder);
- + builder.finish(model, TFLITE_MODEL_IDENTIFIER);
- +
- + return builder.dataBuffer();
- + }
- +
- + /** Creates an example model flatbuffer with the default metadata and data type. */
- + private static ByteBuffer createModelByteBuffer() {
- + ByteBuffer metadata =
- + createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, /*minVersionStr=*/null);
- + return createModelByteBuffer(metadata, DATA_TYPE);
- + }
- +
- + private static ByteBuffer loadMobileNetBuffer() throws Exception {
- + Context context = ApplicationProvider.getApplicationContext();
- + // Loads a MobileNet model flatbuffer with metadata. The MobileNet model is a zip file that
- + // contains a label file as the associated file.
- + AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH);
- + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
- + FileChannel fileChannel = inputStream.getChannel();
- + long startOffset = fileDescriptor.getStartOffset();
- + long declaredLength = fileDescriptor.getDeclaredLength();
- + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
- + }
- +
- + private static ByteBuffer createRandomByteBuffer() {
- + byte[] buffer = new byte[20];
- + new Random().nextBytes(buffer);
- + return ByteBuffer.wrap(buffer);
- }
- - int model = Model.endModel(builder);
- - builder.finish(model, TFLITE_MODEL_IDENTIFIER);
- -
- - return builder.dataBuffer();
- - }
- -
- - /** Creates an example model flatbuffer with the default metadata and data type. */
- - private static ByteBuffer createModelByteBuffer() {
- - ByteBuffer metadata =
- - createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, /*minVersionStr=*/ null);
- - return createModelByteBuffer(metadata, DATA_TYPE);
- - }
- -
- - private static ByteBuffer loadMobileNetBuffer() throws Exception {
- - Context context = ApplicationProvider.getApplicationContext();
- - // Loads a MobileNet model flatbuffer with metadata. The MobileNet model is a zip file that
- - // contains a label file as the associated file.
- - AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH);
- - FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
- - FileChannel fileChannel = inputStream.getChannel();
- - long startOffset = fileDescriptor.getStartOffset();
- - long declaredLength = fileDescriptor.getDeclaredLength();
- - return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
- - }
- -
- - private static ByteBuffer createRandomByteBuffer() {
- - byte[] buffer = new byte[20];
- - new Random().nextBytes(buffer);
- - return ByteBuffer.wrap(buffer);
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataParserTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataParserTest.java
- index a47566fec06e9..eede6750ea479 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataParserTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataParserTest.java
- @@ -17,20 +17,20 @@ package org.tensorflow.lite.support.metadata;
-
- import static com.google.common.truth.Truth.assertThat;
-
- -import java.util.regex.Pattern;
- import org.junit.Test;
- import org.junit.runner.RunWith;
- import org.junit.runners.JUnit4;
-
- +import java.util.regex.Pattern;
- +
- /** Tests of {@link MetadataParser}. */
- @RunWith(JUnit4.class)
- public final class MetadataParserTest {
- -
- - @Test
- - public void version_wellFormedAsSemanticVersion() throws Exception {
- - // Validates that the version is well-formed (x.y.z).
- - String pattern = "[0-9]+\\.[0-9]+\\.[0-9]+";
- - Pattern r = Pattern.compile(pattern);
- - assertThat(MetadataParser.VERSION).matches(r);
- - }
- + @Test
- + public void version_wellFormedAsSemanticVersion() throws Exception {
- + // Validates that the version is well-formed (x.y.z).
- + String pattern = "[0-9]+\\.[0-9]+\\.[0-9]+";
- + Pattern r = Pattern.compile(pattern);
- + assertThat(MetadataParser.VERSION).matches(r);
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java
- index 61231e902e03e..80d2ddc6fd34e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java
- @@ -16,11 +16,20 @@ limitations under the License.
- package org.tensorflow.lite.support.metadata;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertThrows;
-
- import android.content.Context;
- import android.content.res.AssetFileDescriptor;
- +
- import androidx.test.core.app.ApplicationProvider;
- +
- +import org.apache.commons.io.IOUtils;
- +import org.junit.Ignore;
- +import org.junit.Test;
- +import org.junit.runner.RunWith;
- +import org.robolectric.RobolectricTestRunner;
- +
- import java.io.FileInputStream;
- import java.io.InputStream;
- import java.nio.ByteBuffer;
- @@ -28,113 +37,102 @@ import java.nio.channels.FileChannel;
- import java.util.HashSet;
- import java.util.Set;
- import java.util.zip.ZipException;
- -import org.apache.commons.io.IOUtils;
- -import org.junit.Test;
- -import org.junit.runner.RunWith;
- -import org.robolectric.RobolectricTestRunner;
- -
- -import org.junit.Ignore;
-
- /** Tests of {@link ZipFile}. */
- @RunWith(RobolectricTestRunner.class)
- public final class ZipFileTest {
- -
- - // The TFLite model file is a zip file.
- - private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite";
- - // labels.txt is packed in mobilenet_v1_1.0_224_quant.tflite as an associated file.
- - private static final String VALID_LABEL_FILE_NAME = "labels.txt";
- - // invalid.txt is not packed in mobilenet_v1_1.0_224_quant.tflite.
- - private static final String INVALID_LABEL_FILE_NAME = "invalid.txt";
- - private final Context context = ApplicationProvider.getApplicationContext();
- -
- - @Test
- - public void zipFile_nullChannel_throwsException() throws Exception {
- - NullPointerException exception =
- - assertThrows(NullPointerException.class, () -> ZipFile.createFrom(null));
- - assertThat(exception).hasMessageThat().isEqualTo("The object reference is null.");
- - }
- -
- - @Test
- - public void zipFile_invalidFileWithExtremeSmallSize_throwsException() throws Exception {
- - // The size limit for a zip file is the End head size, ZipConstant.ENDHDR, which is 22.
- - ByteBuffer modelBuffer = ByteBuffer.allocate(21);
- - ByteBufferChannel modelChannel = new ByteBufferChannel(modelBuffer);
- -
- - ZipException exception =
- - assertThrows(ZipException.class, () -> ZipFile.createFrom(modelChannel));
- - assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive.");
- - }
- -
- - @Test
- - public void zipFile_invalidFileWithNoSignature_throwsException() throws Exception {
- - // An invalid zip file that meets the size requirement but does not contain the zip signature.
- - ByteBuffer modelBuffer = ByteBuffer.allocate(22);
- - ByteBufferChannel modelChannel = new ByteBufferChannel(modelBuffer);
- -
- - ZipException exception =
- - assertThrows(ZipException.class, () -> ZipFile.createFrom(modelChannel));
- - assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive.");
- - }
- -
- - @Ignore
- - @Test
- - public void getFileNames_correctFileName() throws Exception {
- - ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
- - ZipFile zipFile = ZipFile.createFrom(modelChannel);
- - Set<String> expectedSet = new HashSet<>();
- - expectedSet.add(VALID_LABEL_FILE_NAME);
- - assertThat(zipFile.getFileNames()).isEqualTo(expectedSet);
- - }
- -
- - @Ignore
- - @Test
- - public void getRawInputStream_existentFile() throws Exception {
- - ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
- - ZipFile zipFile = ZipFile.createFrom(modelChannel);
- - InputStream fileStream = zipFile.getRawInputStream(VALID_LABEL_FILE_NAME);
- -
- - // Reads the golden file from context.
- - InputStream goldenFileStream = context.getAssets().open(VALID_LABEL_FILE_NAME);
- - assertThat(IOUtils.contentEquals(goldenFileStream, fileStream)).isTrue();
- - }
- -
- - @Ignore
- - @Test
- - public void getRawInputStream_nonExistentFile() throws Exception {
- - ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
- - ZipFile zipFile = ZipFile.createFrom(modelChannel);
- -
- - IllegalArgumentException exception =
- - assertThrows(
- - IllegalArgumentException.class,
- - () -> zipFile.getRawInputStream(INVALID_LABEL_FILE_NAME));
- - assertThat(exception)
- - .hasMessageThat()
- - .isEqualTo(
- - String.format(
- + // The TFLite model file is a zip file.
- + private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite";
- + // labels.txt is packed in mobilenet_v1_1.0_224_quant.tflite as an associated file.
- + private static final String VALID_LABEL_FILE_NAME = "labels.txt";
- + // invalid.txt is not packed in mobilenet_v1_1.0_224_quant.tflite.
- + private static final String INVALID_LABEL_FILE_NAME = "invalid.txt";
- + private final Context context = ApplicationProvider.getApplicationContext();
- +
- + @Test
- + public void zipFile_nullChannel_throwsException() throws Exception {
- + NullPointerException exception =
- + assertThrows(NullPointerException.class, () -> ZipFile.createFrom(null));
- + assertThat(exception).hasMessageThat().isEqualTo("The object reference is null.");
- + }
- +
- + @Test
- + public void zipFile_invalidFileWithExtremeSmallSize_throwsException() throws Exception {
- + // The size limit for a zip file is the End head size, ZipConstant.ENDHDR, which is 22.
- + ByteBuffer modelBuffer = ByteBuffer.allocate(21);
- + ByteBufferChannel modelChannel = new ByteBufferChannel(modelBuffer);
- +
- + ZipException exception =
- + assertThrows(ZipException.class, () -> ZipFile.createFrom(modelChannel));
- + assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive.");
- + }
- +
- + @Test
- + public void zipFile_invalidFileWithNoSignature_throwsException() throws Exception {
- + // An invalid zip file that meets the size requirement but does not contain the zip
- + // signature.
- + ByteBuffer modelBuffer = ByteBuffer.allocate(22);
- + ByteBufferChannel modelChannel = new ByteBufferChannel(modelBuffer);
- +
- + ZipException exception =
- + assertThrows(ZipException.class, () -> ZipFile.createFrom(modelChannel));
- + assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive.");
- + }
- +
- + @Ignore
- + @Test
- + public void getFileNames_correctFileName() throws Exception {
- + ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
- + ZipFile zipFile = ZipFile.createFrom(modelChannel);
- + Set<String> expectedSet = new HashSet<>();
- + expectedSet.add(VALID_LABEL_FILE_NAME);
- + assertThat(zipFile.getFileNames()).isEqualTo(expectedSet);
- + }
- +
- + @Ignore
- + @Test
- + public void getRawInputStream_existentFile() throws Exception {
- + ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
- + ZipFile zipFile = ZipFile.createFrom(modelChannel);
- + InputStream fileStream = zipFile.getRawInputStream(VALID_LABEL_FILE_NAME);
- +
- + // Reads the golden file from context.
- + InputStream goldenFileStream = context.getAssets().open(VALID_LABEL_FILE_NAME);
- + assertThat(IOUtils.contentEquals(goldenFileStream, fileStream)).isTrue();
- + }
- +
- + @Ignore
- + @Test
- + public void getRawInputStream_nonExistentFile() throws Exception {
- + ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
- + ZipFile zipFile = ZipFile.createFrom(modelChannel);
- +
- + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
- + () -> zipFile.getRawInputStream(INVALID_LABEL_FILE_NAME));
- + assertThat(exception).hasMessageThat().isEqualTo(String.format(
- "The file, %s, does not exist in the zip file.", INVALID_LABEL_FILE_NAME));
- - }
- -
- - @Ignore
- - @Test
- - public void close_validStatus() throws Exception {
- - ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
- - ZipFile zipFile = ZipFile.createFrom(modelChannel);
- - // Should do nothing (including not throwing an exception).
- - zipFile.close();
- - }
- -
- - private static ByteBufferChannel loadModel(String modelPath) throws Exception {
- - // Creates a ZipFile with a TFLite model flatbuffer with metadata. The MobileNet
- - // model is a zip file that contains a label file as the associated file.
- - Context context = ApplicationProvider.getApplicationContext();
- - AssetFileDescriptor fileDescriptor = context.getAssets().openFd(modelPath);
- - FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
- - FileChannel fileChannel = inputStream.getChannel();
- - long startOffset = fileDescriptor.getStartOffset();
- - long declaredLength = fileDescriptor.getDeclaredLength();
- - ByteBuffer modelBuffer =
- - fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
- - return new ByteBufferChannel(modelBuffer);
- - }
- + }
- +
- + @Ignore
- + @Test
- + public void close_validStatus() throws Exception {
- + ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
- + ZipFile zipFile = ZipFile.createFrom(modelChannel);
- + // Should do nothing (including not throwing an exception).
- + zipFile.close();
- + }
- +
- + private static ByteBufferChannel loadModel(String modelPath) throws Exception {
- + // Creates a ZipFile with a TFLite model flatbuffer with metadata. The MobileNet
- + // model is a zip file that contains a label file as the associated file.
- + Context context = ApplicationProvider.getApplicationContext();
- + AssetFileDescriptor fileDescriptor = context.getAssets().openFd(modelPath);
- + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
- + FileChannel fileChannel = inputStream.getChannel();
- + long startOffset = fileDescriptor.getStartOffset();
- + long declaredLength = fileDescriptor.getDeclaredLength();
- + ByteBuffer modelBuffer =
- + fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
- + return new ByteBufferChannel(modelBuffer);
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h
- index 110186bb63a1b..18797d8135eb8 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h
- @@ -19,7 +19,8 @@
- NS_ASSUME_NONNULL_BEGIN
-
- /** Types of image sources. */
- -typedef NSInteger GMLImageSourceType NS_TYPED_ENUM NS_SWIFT_NAME(MLImageSourceType);
- +typedef NSInteger GMLImageSourceType
- + NS_TYPED_ENUM NS_SWIFT_NAME(MLImageSourceType);
- /** Image source is a `UIImage`. */
- static const GMLImageSourceType GMLImageSourceTypeImage = 0;
- /** Image source is a `CVPixelBuffer`. */
- @@ -38,8 +39,9 @@ NS_SWIFT_NAME(MLImage)
- @property(nonatomic, readonly) CGFloat height;
-
- /**
- - * The display orientation of the image. If `imageSourceType` is `.image`, the default value is
- - * `image.imageOrientation`; otherwise the default value is `.up`.
- + * The display orientation of the image. If `imageSourceType` is `.image`, the
- + * default value is `image.imageOrientation`; otherwise the default value is
- + * `.up`.
- */
- @property(nonatomic) UIImageOrientation orientation;
-
- @@ -47,30 +49,34 @@ NS_SWIFT_NAME(MLImage)
- @property(nonatomic, readonly) GMLImageSourceType imageSourceType;
-
- /** The source image. `nil` if `imageSourceType` is not `.image`. */
- -@property(nonatomic, readonly, nullable) UIImage *image;
- +@property(nonatomic, readonly, nullable) UIImage* image;
-
- -/** The source pixel buffer. `nil` if `imageSourceType` is not `.pixelBuffer`. */
- +/** The source pixel buffer. `nil` if `imageSourceType` is not `.pixelBuffer`.
- + */
- @property(nonatomic, readonly, nullable) CVPixelBufferRef pixelBuffer;
-
- -/** The source sample buffer. `nil` if `imageSourceType` is not `.sampleBuffer`. */
- +/** The source sample buffer. `nil` if `imageSourceType` is not `.sampleBuffer`.
- + */
- @property(nonatomic, readonly, nullable) CMSampleBufferRef sampleBuffer;
-
- /**
- * Initializes an `MLImage` object with the given image.
- *
- - * @param image The image to use as the source. Its `CGImage` property must not be `NULL`.
- - * @return A new `MLImage` instance with the given image as the source. `nil` if the given `image`
- - * is `nil` or invalid.
- + * @param image The image to use as the source. Its `CGImage` property must not
- + * be `NULL`.
- + * @return A new `MLImage` instance with the given image as the source. `nil` if
- + * the given `image` is `nil` or invalid.
- */
- -- (nullable instancetype)initWithImage:(UIImage *)image NS_DESIGNATED_INITIALIZER;
- +- (nullable instancetype)initWithImage:(UIImage*)image
- + NS_DESIGNATED_INITIALIZER;
-
- /**
- * Initializes an `MLImage` object with the given pixel buffer.
- *
- - * @param pixelBuffer The pixel buffer to use as the source. It will be retained by the new
- - * `MLImage` instance for the duration of its lifecycle.
- - * @return A new `MLImage` instance with the given pixel buffer as the source. `nil` if the given
- - * pixel buffer is `nil` or invalid.
- + * @param pixelBuffer The pixel buffer to use as the source. It will be retained
- + * by the new `MLImage` instance for the duration of its lifecycle.
- + * @return A new `MLImage` instance with the given pixel buffer as the source.
- + * `nil` if the given pixel buffer is `nil` or invalid.
- */
- - (nullable instancetype)initWithPixelBuffer:(CVPixelBufferRef)pixelBuffer
- NS_DESIGNATED_INITIALIZER;
- @@ -78,12 +84,13 @@ NS_SWIFT_NAME(MLImage)
- /**
- * Initializes an `MLImage` object with the given sample buffer.
- *
- - * @param sampleBuffer The sample buffer to use as the source. It will be retained by the new
- - * `MLImage` instance for the duration of its lifecycle. The sample buffer must be based on a
- - * pixel buffer (not compressed data). In practice, it should be the video output of the camera
- - * on an iOS device, not other arbitrary types of `CMSampleBuffer`s.
- - * @return A new `MLImage` instance with the given sample buffer as the source. `nil` if the given
- - * sample buffer is `nil` or invalid.
- + * @param sampleBuffer The sample buffer to use as the source. It will be
- + * retained by the new `MLImage` instance for the duration of its lifecycle. The
- + * sample buffer must be based on a pixel buffer (not compressed data). In
- + * practice, it should be the video output of the camera on an iOS device, not
- + * other arbitrary types of `CMSampleBuffer`s.
- + * @return A new `MLImage` instance with the given sample buffer as the source.
- + * `nil` if the given sample buffer is `nil` or invalid.
- */
- - (nullable instancetype)initWithSampleBuffer:(CMSampleBufferRef)sampleBuffer
- NS_DESIGNATED_INITIALIZER;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java
- index a32fc24749e0c..59116a72a0533 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java
- @@ -24,28 +24,27 @@ import android.graphics.Bitmap;
- * {@link IllegalArgumentException} will be thrown.
- */
- public final class BitmapExtractor {
- -
- - /**
- - * Extracts a {@link android.graphics.Bitmap} from an {@link MlImage}.
- - *
- - * <p>Notice: Properties of the {@code image} like rotation will not take effects.
- - *
- - * @param image the image to extract {@link android.graphics.Bitmap} from.
- - * @return the {@link android.graphics.Bitmap} stored in {@link MlImage}
- - * @throws IllegalArgumentException when the extraction requires unsupported format or data type
- - * conversions.
- - */
- - public static Bitmap extract(MlImage image) {
- - ImageContainer imageContainer = image.getContainer(MlImage.STORAGE_TYPE_BITMAP);
- - if (imageContainer != null) {
- - return ((BitmapImageContainer) imageContainer).getBitmap();
- - } else {
- - // TODO(b/180504869): Support ByteBuffer -> Bitmap conversion.
- - throw new IllegalArgumentException(
- - "Extracting Bitmap from an MlImage created by objects other than Bitmap is not"
- - + " supported");
- + /**
- + * Extracts a {@link android.graphics.Bitmap} from an {@link MlImage}.
- + *
- + * <p>Notice: Properties of the {@code image} like rotation will not take effects.
- + *
- + * @param image the image to extract {@link android.graphics.Bitmap} from.
- + * @return the {@link android.graphics.Bitmap} stored in {@link MlImage}
- + * @throws IllegalArgumentException when the extraction requires unsupported format or data type
- + * conversions.
- + */
- + public static Bitmap extract(MlImage image) {
- + ImageContainer imageContainer = image.getContainer(MlImage.STORAGE_TYPE_BITMAP);
- + if (imageContainer != null) {
- + return ((BitmapImageContainer) imageContainer).getBitmap();
- + } else {
- + // TODO(b/180504869): Support ByteBuffer -> Bitmap conversion.
- + throw new IllegalArgumentException(
- + "Extracting Bitmap from an MlImage created by objects other than Bitmap is not"
- + + " supported");
- + }
- }
- - }
-
- - private BitmapExtractor() {}
- + private BitmapExtractor() {}
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java
- index 77e63f0351449..b1b02f8e369ec 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java
- @@ -16,44 +16,44 @@ limitations under the License.
- package com.google.android.odml.image;
-
- import android.graphics.Bitmap;
- +
- import com.google.android.odml.image.MlImage.ImageFormat;
-
- class BitmapImageContainer implements ImageContainer {
- + private final Bitmap bitmap;
- + private final ImageProperties properties;
- +
- + public BitmapImageContainer(Bitmap bitmap) {
- + this.bitmap = bitmap;
- + this.properties = ImageProperties.builder()
- + .setImageFormat(convertFormatCode(bitmap.getConfig()))
- + .setStorageType(MlImage.STORAGE_TYPE_BITMAP)
- + .build();
- + }
- +
- + public Bitmap getBitmap() {
- + return bitmap;
- + }
- +
- + @Override
- + public ImageProperties getImageProperties() {
- + return properties;
- + }
- +
- + @Override
- + public void close() {
- + bitmap.recycle();
- + }
-
- - private final Bitmap bitmap;
- - private final ImageProperties properties;
- -
- - public BitmapImageContainer(Bitmap bitmap) {
- - this.bitmap = bitmap;
- - this.properties = ImageProperties.builder()
- - .setImageFormat(convertFormatCode(bitmap.getConfig()))
- - .setStorageType(MlImage.STORAGE_TYPE_BITMAP)
- - .build();
- - }
- -
- - public Bitmap getBitmap() {
- - return bitmap;
- - }
- -
- - @Override
- - public ImageProperties getImageProperties() {
- - return properties;
- - }
- -
- - @Override
- - public void close() {
- - bitmap.recycle();
- - }
- -
- - @ImageFormat
- - static int convertFormatCode(Bitmap.Config config) {
- - switch (config) {
- - case ALPHA_8:
- - return MlImage.IMAGE_FORMAT_ALPHA;
- - case ARGB_8888:
- - return MlImage.IMAGE_FORMAT_RGBA;
- - default:
- - return MlImage.IMAGE_FORMAT_UNKNOWN;
- + @ImageFormat
- + static int convertFormatCode(Bitmap.Config config) {
- + switch (config) {
- + case ALPHA_8:
- + return MlImage.IMAGE_FORMAT_ALPHA;
- + case ARGB_8888:
- + return MlImage.IMAGE_FORMAT_RGBA;
- + default:
- + return MlImage.IMAGE_FORMAT_UNKNOWN;
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java
- index fe9c35a8a6ede..6c4552bfdac3a 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java
- @@ -20,6 +20,7 @@ import android.graphics.Bitmap;
- import android.graphics.Rect;
- import android.net.Uri;
- import android.provider.MediaStore;
- +
- import java.io.IOException;
-
- /**
- @@ -32,82 +33,76 @@ import java.io.IOException;
- * <p>Use {@link BitmapExtractor} to get {@link android.graphics.Bitmap} you passed in.
- */
- public class BitmapMlImageBuilder {
- + // Mandatory fields.
- + private final Bitmap bitmap;
-
- - // Mandatory fields.
- - private final Bitmap bitmap;
- -
- - // Optional fields.
- - private int rotation;
- - private Rect roi;
- - private long timestamp;
- + // Optional fields.
- + private int rotation;
- + private Rect roi;
- + private long timestamp;
-
- - /**
- - * Creates the builder with a mandatory {@link android.graphics.Bitmap}.
- - *
- - * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the values
- - * will be set with default:
- - *
- - * <ul>
- - * <li>rotation: 0
- - * </ul>
- - *
- - * @param bitmap image data object.
- - */
- - public BitmapMlImageBuilder(Bitmap bitmap) {
- - this.bitmap = bitmap;
- - rotation = 0;
- - roi = new Rect(0, 0, bitmap.getWidth(), bitmap.getHeight());
- - timestamp = 0;
- - }
- + /**
- + * Creates the builder with a mandatory {@link android.graphics.Bitmap}.
- + *
- + * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the
- + * values will be set with default:
- + *
- + * <ul>
- + * <li>rotation: 0
- + * </ul>
- + *
- + * @param bitmap image data object.
- + */
- + public BitmapMlImageBuilder(Bitmap bitmap) {
- + this.bitmap = bitmap;
- + rotation = 0;
- + roi = new Rect(0, 0, bitmap.getWidth(), bitmap.getHeight());
- + timestamp = 0;
- + }
-
- - /**
- - * Creates the builder to build {@link MlImage} from a file.
- - *
- - * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the values
- - * will be set with default:
- - *
- - * <ul>
- - * <li>rotation: 0
- - * </ul>
- - *
- - * @param context the application context.
- - * @param uri the path to the resource file.
- - */
- - public BitmapMlImageBuilder(Context context, Uri uri) throws IOException {
- - this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
- - }
- + /**
- + * Creates the builder to build {@link MlImage} from a file.
- + *
- + * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the
- + * values will be set with default:
- + *
- + * <ul>
- + * <li>rotation: 0
- + * </ul>
- + *
- + * @param context the application context.
- + * @param uri the path to the resource file.
- + */
- + public BitmapMlImageBuilder(Context context, Uri uri) throws IOException {
- + this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
- + }
-
- - /**
- - * Sets value for {@link MlImage#getRotation()}.
- - *
- - * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
- - */
- - public BitmapMlImageBuilder setRotation(int rotation) {
- - MlImage.validateRotation(rotation);
- - this.rotation = rotation;
- - return this;
- - }
- + /**
- + * Sets value for {@link MlImage#getRotation()}.
- + *
- + * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
- + */
- + public BitmapMlImageBuilder setRotation(int rotation) {
- + MlImage.validateRotation(rotation);
- + this.rotation = rotation;
- + return this;
- + }
-
- - /** Sets value for {@link MlImage#getRoi()}. */
- - BitmapMlImageBuilder setRoi(Rect roi) {
- - this.roi = roi;
- - return this;
- - }
- + /** Sets value for {@link MlImage#getRoi()}. */
- + BitmapMlImageBuilder setRoi(Rect roi) {
- + this.roi = roi;
- + return this;
- + }
-
- - /** Sets value for {@link MlImage#getTimestamp()}. */
- - BitmapMlImageBuilder setTimestamp(long timestamp) {
- - this.timestamp = timestamp;
- - return this;
- - }
- + /** Sets value for {@link MlImage#getTimestamp()}. */
- + BitmapMlImageBuilder setTimestamp(long timestamp) {
- + this.timestamp = timestamp;
- + return this;
- + }
-
- - /** Builds an {@link MlImage} instance. */
- - public MlImage build() {
- - return new MlImage(
- - new BitmapImageContainer(bitmap),
- - rotation,
- - roi,
- - timestamp,
- - bitmap.getWidth(),
- - bitmap.getHeight());
- - }
- + /** Builds an {@link MlImage} instance. */
- + public MlImage build() {
- + return new MlImage(new BitmapImageContainer(bitmap), rotation, roi, timestamp,
- + bitmap.getWidth(), bitmap.getHeight());
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java
- index 7b86be6d1b533..d5861c8ca94ac 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java
- @@ -19,8 +19,10 @@ import android.graphics.Bitmap;
- import android.graphics.Bitmap.Config;
- import android.os.Build.VERSION;
- import android.os.Build.VERSION_CODES;
- +
- import com.google.android.odml.image.MlImage.ImageFormat;
- import com.google.auto.value.AutoValue;
- +
- import java.nio.ByteBuffer;
- import java.nio.ByteOrder;
- import java.util.Locale;
- @@ -32,229 +34,234 @@ import java.util.Locale;
- * otherwise {@link IllegalArgumentException} will be thrown.
- */
- public class ByteBufferExtractor {
- -
- - /**
- - * Extracts a {@link ByteBuffer} from an {@link MlImage}.
- - *
- - * <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
- - * ImageProperties} whose storage type is {@code MlImage.STORAGE_TYPE_BYTEBUFFER}.
- - *
- - * @see MlImage#getContainedImageProperties()
- - * @return A read-only {@link ByteBuffer}.
- - * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
- - */
- - public static ByteBuffer extract(MlImage image) {
- - ImageContainer container = image.getContainer();
- - switch (container.getImageProperties().getStorageType()) {
- - case MlImage.STORAGE_TYPE_BYTEBUFFER:
- - ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
- - return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
- - default:
- - throw new IllegalArgumentException(
- - "Extract ByteBuffer from an MlImage created by objects other than Bytebuffer is not"
- - + " supported");
- - }
- - }
- -
- - /**
- - * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link MlImage}.
- - *
- - * <p>Notice: Properties of the {@code image} like rotation will not take effects.
- - *
- - * <p>Format conversion spec:
- - *
- - * <ul>
- - * <li>When extracting RGB images to RGBA format, A channel will always set to 255.
- - * <li>When extracting RGBA images to RGB format, A channel will be dropped.
- - * </ul>
- - *
- - * @param image the image to extract buffer from.
- - * @param targetFormat the image format of the result bytebuffer.
- - * @return the readonly {@link ByteBuffer} stored in {@link MlImage}
- - * @throws IllegalArgumentException when the extraction requires unsupported format or data type
- - * conversions.
- - */
- - static ByteBuffer extract(MlImage image, @ImageFormat int targetFormat) {
- - ImageContainer container;
- - ImageProperties byteBufferProperties =
- - ImageProperties.builder()
- - .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
- - .setImageFormat(targetFormat)
- - .build();
- - if ((container = image.getContainer(byteBufferProperties)) != null) {
- - ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
- - return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
- - } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
- - ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
- - @ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
- - return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
- - .asReadOnlyBuffer();
- - } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) {
- - BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
- - ByteBuffer byteBuffer =
- - extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)
- - .asReadOnlyBuffer();
- - image.addContainer(new ByteBufferImageContainer(byteBuffer, targetFormat));
- - return byteBuffer;
- - } else {
- - throw new IllegalArgumentException(
- - "Extracting ByteBuffer from an MlImage created by objects other than Bitmap or"
- - + " Bytebuffer is not supported");
- - }
- - }
- -
- - /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */
- - @AutoValue
- - abstract static class Result {
- /**
- - * Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(MlImage)}.
- + * Extracts a {@link ByteBuffer} from an {@link MlImage}.
- + *
- + * <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
- + * ImageProperties} whose storage type is {@code MlImage.STORAGE_TYPE_BYTEBUFFER}.
- + *
- + * @see MlImage#getContainedImageProperties()
- + * @return A read-only {@link ByteBuffer}.
- + * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
- */
- - public abstract ByteBuffer buffer();
- + public static ByteBuffer extract(MlImage image) {
- + ImageContainer container = image.getContainer();
- + switch (container.getImageProperties().getStorageType()) {
- + case MlImage.STORAGE_TYPE_BYTEBUFFER:
- + ByteBufferImageContainer byteBufferImageContainer =
- + (ByteBufferImageContainer) container;
- + return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
- + default:
- + throw new IllegalArgumentException(
- + "Extract ByteBuffer from an MlImage created by objects other than Bytebuffer is not"
- + + " supported");
- + }
- + }
-
- /**
- - * Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(MlImage)}.
- + * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link MlImage}.
- + *
- + * <p>Notice: Properties of the {@code image} like rotation will not take effects.
- + *
- + * <p>Format conversion spec:
- + *
- + * <ul>
- + * <li>When extracting RGB images to RGBA format, A channel will always set to 255.
- + * <li>When extracting RGBA images to RGB format, A channel will be dropped.
- + * </ul>
- + *
- + * @param image the image to extract buffer from.
- + * @param targetFormat the image format of the result bytebuffer.
- + * @return the readonly {@link ByteBuffer} stored in {@link MlImage}
- + * @throws IllegalArgumentException when the extraction requires unsupported format or data type
- + * conversions.
- */
- - @ImageFormat
- - public abstract int format();
- -
- - static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) {
- - return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
- + static ByteBuffer extract(MlImage image, @ImageFormat int targetFormat) {
- + ImageContainer container;
- + ImageProperties byteBufferProperties =
- + ImageProperties.builder()
- + .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
- + .setImageFormat(targetFormat)
- + .build();
- + if ((container = image.getContainer(byteBufferProperties)) != null) {
- + ByteBufferImageContainer byteBufferImageContainer =
- + (ByteBufferImageContainer) container;
- + return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
- + } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
- + ByteBufferImageContainer byteBufferImageContainer =
- + (ByteBufferImageContainer) container;
- + @ImageFormat
- + int sourceFormat = byteBufferImageContainer.getImageFormat();
- + return convertByteBuffer(
- + byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
- + .asReadOnlyBuffer();
- + } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) {
- + BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
- + ByteBuffer byteBuffer =
- + extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)
- + .asReadOnlyBuffer();
- + image.addContainer(new ByteBufferImageContainer(byteBuffer, targetFormat));
- + return byteBuffer;
- + } else {
- + throw new IllegalArgumentException(
- + "Extracting ByteBuffer from an MlImage created by objects other than Bitmap or"
- + + " Bytebuffer is not supported");
- + }
- }
- - }
-
- - /**
- - * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link MlImage}.
- - *
- - * <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy.
- - *
- - * <p>Notice: Properties of the {@code image} like rotation will not take effects.
- - *
- - * @return the readonly {@link ByteBuffer} stored in {@link MlImage}
- - * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
- - * given {@code imageFormat}
- - */
- - static Result extractInRecommendedFormat(MlImage image) {
- - ImageContainer container;
- - if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) {
- - Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
- - @ImageFormat int format = adviseImageFormat(bitmap);
- - Result result =
- - Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);
- + /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */
- + @AutoValue
- + abstract static class Result {
- + /**
- + * Gets the {@link ByteBuffer} in the result of {@link
- + * ByteBufferExtractor#extract(MlImage)}.
- + */
- + public abstract ByteBuffer buffer();
-
- - image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
- - return result;
- - } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
- - ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
- - return Result.create(
- - byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
- - byteBufferImageContainer.getImageFormat());
- - } else {
- - throw new IllegalArgumentException(
- - "Extract ByteBuffer from an MlImage created by objects other than Bitmap or Bytebuffer"
- - + " is not supported");
- + /**
- + * Gets the {@link ImageFormat} in the result of {@link
- + * ByteBufferExtractor#extract(MlImage)}.
- + */
- + @ImageFormat
- + public abstract int format();
- +
- + static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) {
- + return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
- + }
- }
- - }
-
- - @ImageFormat
- - private static int adviseImageFormat(Bitmap bitmap) {
- - if (bitmap.getConfig() == Config.ARGB_8888) {
- - return MlImage.IMAGE_FORMAT_RGBA;
- - } else {
- - throw new IllegalArgumentException(
- - String.format(
- - "Extracting ByteBuffer from an MlImage created by a Bitmap in config %s is not"
- - + " supported",
- - bitmap.getConfig()));
- + /**
- + * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link MlImage}.
- + *
- + * <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid
- + * copy.
- + *
- + * <p>Notice: Properties of the {@code image} like rotation will not take effects.
- + *
- + * @return the readonly {@link ByteBuffer} stored in {@link MlImage}
- + * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
- + * given {@code imageFormat}
- + */
- + static Result extractInRecommendedFormat(MlImage image) {
- + ImageContainer container;
- + if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) {
- + Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
- + @ImageFormat
- + int format = adviseImageFormat(bitmap);
- + Result result = Result.create(
- + extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);
- +
- + image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
- + return result;
- + } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
- + ByteBufferImageContainer byteBufferImageContainer =
- + (ByteBufferImageContainer) container;
- + return Result.create(byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
- + byteBufferImageContainer.getImageFormat());
- + } else {
- + throw new IllegalArgumentException(
- + "Extract ByteBuffer from an MlImage created by objects other than Bitmap or Bytebuffer"
- + + " is not supported");
- + }
- }
- - }
-
- - private static ByteBuffer extractByteBufferFromBitmap(
- - Bitmap bitmap, @ImageFormat int imageFormat) {
- - if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
- - throw new IllegalArgumentException(
- - "Extracting ByteBuffer from an MlImage created by a premultiplied Bitmap is not"
- - + " supported");
- + @ImageFormat
- + private static int adviseImageFormat(Bitmap bitmap) {
- + if (bitmap.getConfig() == Config.ARGB_8888) {
- + return MlImage.IMAGE_FORMAT_RGBA;
- + } else {
- + throw new IllegalArgumentException(String.format(
- + "Extracting ByteBuffer from an MlImage created by a Bitmap in config %s is not"
- + + " supported",
- + bitmap.getConfig()));
- + }
- }
- - if (bitmap.getConfig() == Config.ARGB_8888) {
- - if (imageFormat == MlImage.IMAGE_FORMAT_RGBA) {
- - ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
- - bitmap.copyPixelsToBuffer(buffer);
- - buffer.rewind();
- - return buffer;
- - } else if (imageFormat == MlImage.IMAGE_FORMAT_RGB) {
- - // TODO(b/180504869): Try Use RGBA buffer to create RGB buffer which might be faster.
- - int w = bitmap.getWidth();
- - int h = bitmap.getHeight();
- - int[] pixels = new int[w * h];
- - bitmap.getPixels(pixels, 0, w, 0, 0, w, h);
- - ByteBuffer buffer = ByteBuffer.allocateDirect(w * h * 3);
- - buffer.order(ByteOrder.nativeOrder());
- - for (int pixel : pixels) {
- - // getPixels returns Color in ARGB rather than copyPixelsToBuffer which returns RGBA
- - buffer.put((byte) ((pixel >> 16) & 0xff));
- - buffer.put((byte) ((pixel >> 8) & 0xff));
- - buffer.put((byte) (pixel & 0xff));
- +
- + private static ByteBuffer extractByteBufferFromBitmap(
- + Bitmap bitmap, @ImageFormat int imageFormat) {
- + if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
- + throw new IllegalArgumentException(
- + "Extracting ByteBuffer from an MlImage created by a premultiplied Bitmap is not"
- + + " supported");
- }
- - buffer.rewind();
- - return buffer;
- - }
- + if (bitmap.getConfig() == Config.ARGB_8888) {
- + if (imageFormat == MlImage.IMAGE_FORMAT_RGBA) {
- + ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
- + bitmap.copyPixelsToBuffer(buffer);
- + buffer.rewind();
- + return buffer;
- + } else if (imageFormat == MlImage.IMAGE_FORMAT_RGB) {
- + // TODO(b/180504869): Try Use RGBA buffer to create RGB buffer which might be
- + // faster.
- + int w = bitmap.getWidth();
- + int h = bitmap.getHeight();
- + int[] pixels = new int[w * h];
- + bitmap.getPixels(pixels, 0, w, 0, 0, w, h);
- + ByteBuffer buffer = ByteBuffer.allocateDirect(w * h * 3);
- + buffer.order(ByteOrder.nativeOrder());
- + for (int pixel : pixels) {
- + // getPixels returns Color in ARGB rather than copyPixelsToBuffer which returns
- + // RGBA
- + buffer.put((byte) ((pixel >> 16) & 0xff));
- + buffer.put((byte) ((pixel >> 8) & 0xff));
- + buffer.put((byte) (pixel & 0xff));
- + }
- + buffer.rewind();
- + return buffer;
- + }
- + }
- + throw new IllegalArgumentException(String.format(
- + "Extracting ByteBuffer from an MlImage created by Bitmap and convert from %s to format"
- + + " %d is not supported",
- + bitmap.getConfig(), imageFormat));
- }
- - throw new IllegalArgumentException(
- - String.format(
- - "Extracting ByteBuffer from an MlImage created by Bitmap and convert from %s to format"
- - + " %d is not supported",
- - bitmap.getConfig(), imageFormat));
- - }
-
- - private static ByteBuffer convertByteBuffer(
- - ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) {
- - if (sourceFormat == MlImage.IMAGE_FORMAT_RGB && targetFormat == MlImage.IMAGE_FORMAT_RGBA) {
- - ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
- - // Extend the buffer when the target is longer than the source. Use two cursors and sweep the
- - // array reversely to convert in-place.
- - byte[] array = new byte[target.capacity()];
- - source.get(array, 0, source.capacity());
- - source.rewind();
- - int rgbCursor = source.capacity();
- - int rgbaCursor = target.capacity();
- - while (rgbCursor != rgbaCursor) {
- - array[--rgbaCursor] = (byte) 0xff; // A
- - array[--rgbaCursor] = array[--rgbCursor]; // B
- - array[--rgbaCursor] = array[--rgbCursor]; // G
- - array[--rgbaCursor] = array[--rgbCursor]; // R
- - }
- - target.put(array, 0, target.capacity());
- - target.rewind();
- - return target;
- - } else if (sourceFormat == MlImage.IMAGE_FORMAT_RGBA
- - && targetFormat == MlImage.IMAGE_FORMAT_RGB) {
- - ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
- - // Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the
- - // array to convert in-place.
- - byte[] array = new byte[source.capacity()];
- - source.get(array, 0, source.capacity());
- - source.rewind();
- - int rgbaCursor = 0;
- - int rgbCursor = 0;
- - while (rgbaCursor < array.length) {
- - array[rgbCursor++] = array[rgbaCursor++]; // R
- - array[rgbCursor++] = array[rgbaCursor++]; // G
- - array[rgbCursor++] = array[rgbaCursor++]; // B
- - rgbaCursor++;
- - }
- - target.put(array, 0, target.capacity());
- - target.rewind();
- - return target;
- - } else {
- - throw new IllegalArgumentException(
- - String.format(
- - Locale.ENGLISH,
- - "Convert bytebuffer image format from %d to %d is not supported",
- - sourceFormat,
- - targetFormat));
- + private static ByteBuffer convertByteBuffer(
- + ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) {
- + if (sourceFormat == MlImage.IMAGE_FORMAT_RGB && targetFormat == MlImage.IMAGE_FORMAT_RGBA) {
- + ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
- + // Extend the buffer when the target is longer than the source. Use two cursors and
- + // sweep the array reversely to convert in-place.
- + byte[] array = new byte[target.capacity()];
- + source.get(array, 0, source.capacity());
- + source.rewind();
- + int rgbCursor = source.capacity();
- + int rgbaCursor = target.capacity();
- + while (rgbCursor != rgbaCursor) {
- + array[--rgbaCursor] = (byte) 0xff; // A
- + array[--rgbaCursor] = array[--rgbCursor]; // B
- + array[--rgbaCursor] = array[--rgbCursor]; // G
- + array[--rgbaCursor] = array[--rgbCursor]; // R
- + }
- + target.put(array, 0, target.capacity());
- + target.rewind();
- + return target;
- + } else if (sourceFormat == MlImage.IMAGE_FORMAT_RGBA
- + && targetFormat == MlImage.IMAGE_FORMAT_RGB) {
- + ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
- + // Shrink the buffer when the target is shorter than the source. Use two cursors and
- + // sweep the array to convert in-place.
- + byte[] array = new byte[source.capacity()];
- + source.get(array, 0, source.capacity());
- + source.rewind();
- + int rgbaCursor = 0;
- + int rgbCursor = 0;
- + while (rgbaCursor < array.length) {
- + array[rgbCursor++] = array[rgbaCursor++]; // R
- + array[rgbCursor++] = array[rgbaCursor++]; // G
- + array[rgbCursor++] = array[rgbaCursor++]; // B
- + rgbaCursor++;
- + }
- + target.put(array, 0, target.capacity());
- + target.rewind();
- + return target;
- + } else {
- + throw new IllegalArgumentException(String.format(Locale.ENGLISH,
- + "Convert bytebuffer image format from %d to %d is not supported", sourceFormat,
- + targetFormat));
- + }
- }
- - }
-
- - // ByteBuffer is not able to be instantiated.
- - private ByteBufferExtractor() {}
- + // ByteBuffer is not able to be instantiated.
- + private ByteBufferExtractor() {}
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java
- index 9fbc3cbb94994..f872db485a8a2 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java
- @@ -16,42 +16,40 @@ limitations under the License.
- package com.google.android.odml.image;
-
- import com.google.android.odml.image.MlImage.ImageFormat;
- +
- import java.nio.ByteBuffer;
-
- class ByteBufferImageContainer implements ImageContainer {
- -
- - private final ByteBuffer buffer;
- - private final ImageProperties properties;
- -
- - public ByteBufferImageContainer(
- - ByteBuffer buffer,
- - @ImageFormat int imageFormat) {
- - this.buffer = buffer;
- - this.properties = ImageProperties.builder()
- - .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
- - .setImageFormat(imageFormat)
- - .build();
- - }
- -
- - public ByteBuffer getByteBuffer() {
- - return buffer;
- - }
- -
- - @Override
- - public ImageProperties getImageProperties() {
- - return properties;
- - }
- -
- - /**
- - * Returns the image format.
- - */
- - @ImageFormat
- - public int getImageFormat() {
- - return properties.getImageFormat();
- - }
- -
- - @Override
- - public void close() {
- - // No op for ByteBuffer.
- - }
- + private final ByteBuffer buffer;
- + private final ImageProperties properties;
- +
- + public ByteBufferImageContainer(ByteBuffer buffer, @ImageFormat int imageFormat) {
- + this.buffer = buffer;
- + this.properties = ImageProperties.builder()
- + .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
- + .setImageFormat(imageFormat)
- + .build();
- + }
- +
- + public ByteBuffer getByteBuffer() {
- + return buffer;
- + }
- +
- + @Override
- + public ImageProperties getImageProperties() {
- + return properties;
- + }
- +
- + /**
- + * Returns the image format.
- + */
- + @ImageFormat
- + public int getImageFormat() {
- + return properties.getImageFormat();
- + }
- +
- + @Override
- + public void close() {
- + // No op for ByteBuffer.
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java
- index 421e2b8f0de31..f4b0b31dd5e3b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java
- @@ -16,7 +16,9 @@ limitations under the License.
- package com.google.android.odml.image;
-
- import android.graphics.Rect;
- +
- import com.google.android.odml.image.MlImage.ImageFormat;
- +
- import java.nio.ByteBuffer;
-
- /**
- @@ -28,79 +30,74 @@ import java.nio.ByteBuffer;
- * <p>Use {@link ByteBufferExtractor} to get {@link ByteBuffer} you passed in.
- */
- public class ByteBufferMlImageBuilder {
- + // Mandatory fields.
- + private final ByteBuffer buffer;
- + private final int width;
- + private final int height;
- + @ImageFormat
- + private final int imageFormat;
-
- - // Mandatory fields.
- - private final ByteBuffer buffer;
- - private final int width;
- - private final int height;
- - @ImageFormat private final int imageFormat;
- -
- - // Optional fields.
- - private int rotation;
- - private Rect roi;
- - private long timestamp;
- + // Optional fields.
- + private int rotation;
- + private Rect roi;
- + private long timestamp;
-
- - /**
- - * Creates the builder with mandatory {@link ByteBuffer} and the represented image.
- - *
- - * <p>We will validate the size of the {@code byteBuffer} with given {@code width}, {@code height}
- - * and {@code imageFormat}.
- - *
- - * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the values
- - * will be set with default:
- - *
- - * <ul>
- - * <li>rotation: 0
- - * </ul>
- - *
- - * @param byteBuffer image data object.
- - * @param width the width of the represented image.
- - * @param height the height of the represented image.
- - * @param imageFormat how the data encode the image.
- - */
- - public ByteBufferMlImageBuilder(
- - ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) {
- - this.buffer = byteBuffer;
- - this.width = width;
- - this.height = height;
- - this.imageFormat = imageFormat;
- - // TODO(b/180504869): Validate bytebuffer size with width, height and image format
- - this.rotation = 0;
- - this.roi = new Rect(0, 0, width, height);
- - this.timestamp = 0;
- - }
- + /**
- + * Creates the builder with mandatory {@link ByteBuffer} and the represented image.
- + *
- + * <p>We will validate the size of the {@code byteBuffer} with given {@code width}, {@code
- + * height} and {@code imageFormat}.
- + *
- + * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the
- + * values will be set with default:
- + *
- + * <ul>
- + * <li>rotation: 0
- + * </ul>
- + *
- + * @param byteBuffer image data object.
- + * @param width the width of the represented image.
- + * @param height the height of the represented image.
- + * @param imageFormat how the data encode the image.
- + */
- + public ByteBufferMlImageBuilder(
- + ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) {
- + this.buffer = byteBuffer;
- + this.width = width;
- + this.height = height;
- + this.imageFormat = imageFormat;
- + // TODO(b/180504869): Validate bytebuffer size with width, height and image format
- + this.rotation = 0;
- + this.roi = new Rect(0, 0, width, height);
- + this.timestamp = 0;
- + }
-
- - /**
- - * Sets value for {@link MlImage#getRotation()}.
- - *
- - * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
- - */
- - public ByteBufferMlImageBuilder setRotation(int rotation) {
- - MlImage.validateRotation(rotation);
- - this.rotation = rotation;
- - return this;
- - }
- + /**
- + * Sets value for {@link MlImage#getRotation()}.
- + *
- + * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
- + */
- + public ByteBufferMlImageBuilder setRotation(int rotation) {
- + MlImage.validateRotation(rotation);
- + this.rotation = rotation;
- + return this;
- + }
-
- - /** Sets value for {@link MlImage#getRoi()}. */
- - ByteBufferMlImageBuilder setRoi(Rect roi) {
- - this.roi = roi;
- - return this;
- - }
- + /** Sets value for {@link MlImage#getRoi()}. */
- + ByteBufferMlImageBuilder setRoi(Rect roi) {
- + this.roi = roi;
- + return this;
- + }
-
- - /** Sets value for {@link MlImage#getTimestamp()}. */
- - ByteBufferMlImageBuilder setTimestamp(long timestamp) {
- - this.timestamp = timestamp;
- - return this;
- - }
- + /** Sets value for {@link MlImage#getTimestamp()}. */
- + ByteBufferMlImageBuilder setTimestamp(long timestamp) {
- + this.timestamp = timestamp;
- + return this;
- + }
-
- - /** Builds an {@link MlImage} instance. */
- - public MlImage build() {
- - return new MlImage(
- - new ByteBufferImageContainer(buffer, imageFormat),
- - rotation,
- - roi,
- - timestamp,
- - width,
- - height);
- - }
- + /** Builds an {@link MlImage} instance. */
- + public MlImage build() {
- + return new MlImage(new ByteBufferImageContainer(buffer, imageFormat), rotation, roi,
- + timestamp, width, height);
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java
- index 25ed2312ce580..bfa7c0a292f4f 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java
- @@ -20,11 +20,11 @@ import com.google.android.odml.image.annotation.KeepForSdk;
- /** Manages internal image data storage. The interface is package-private. */
- @KeepForSdk
- interface ImageContainer {
- - /** Returns the properties of the contained image. */
- - @KeepForSdk
- - ImageProperties getImageProperties();
- + /** Returns the properties of the contained image. */
- + @KeepForSdk
- + ImageProperties getImageProperties();
-
- - /** Close the image container and releases the image resource inside. */
- - @KeepForSdk
- - void close();
- + /** Close the image container and releases the image resource inside. */
- + @KeepForSdk
- + void close();
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java
- index 717bc5f9935ed..a61e97b81b872 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java
- @@ -24,63 +24,61 @@ import com.google.auto.value.extension.memoized.Memoized;
- /** Groups a set of properties to describe how an image is stored. */
- @AutoValue
- public abstract class ImageProperties {
- -
- - /**
- - * Gets the pixel format of the image.
- - *
- - * @see MlImage.ImageFormat
- - */
- - @ImageFormat
- - public abstract int getImageFormat();
- -
- - /**
- - * Gets the storage type of the image.
- - *
- - * @see MlImage.StorageType
- - */
- - @StorageType
- - public abstract int getStorageType();
- -
- - @Memoized
- - @Override
- - public abstract int hashCode();
- -
- - /**
- - * Creates a builder of {@link ImageProperties}.
- - *
- - * @see ImageProperties.Builder
- - */
- - @KeepForSdk
- - static Builder builder() {
- - return new AutoValue_ImageProperties.Builder();
- - }
- -
- - /** Builds a {@link ImageProperties}. */
- - @AutoValue.Builder
- - @KeepForSdk
- - abstract static class Builder {
- + /**
- + * Gets the pixel format of the image.
- + *
- + * @see MlImage.ImageFormat
- + */
- + @ImageFormat
- + public abstract int getImageFormat();
-
- /**
- - * Sets the {@link MlImage.ImageFormat}.
- + * Gets the storage type of the image.
- *
- - * @see ImageProperties#getImageFormat
- + * @see MlImage.StorageType
- */
- - @KeepForSdk
- - abstract Builder setImageFormat(@ImageFormat int value);
- + @StorageType
- + public abstract int getStorageType();
- +
- + @Memoized
- + @Override
- + public abstract int hashCode();
-
- /**
- - * Sets the {@link MlImage.StorageType}.
- + * Creates a builder of {@link ImageProperties}.
- *
- - * @see ImageProperties#getStorageType
- + * @see ImageProperties.Builder
- */
- @KeepForSdk
- - abstract Builder setStorageType(@StorageType int value);
- + static Builder builder() {
- + return new AutoValue_ImageProperties.Builder();
- + }
-
- - /** Builds the {@link ImageProperties}. */
- + /** Builds a {@link ImageProperties}. */
- + @AutoValue.Builder
- @KeepForSdk
- - abstract ImageProperties build();
- - }
- + abstract static class Builder {
- + /**
- + * Sets the {@link MlImage.ImageFormat}.
- + *
- + * @see ImageProperties#getImageFormat
- + */
- + @KeepForSdk
- + abstract Builder setImageFormat(@ImageFormat int value);
- +
- + /**
- + * Sets the {@link MlImage.StorageType}.
- + *
- + * @see ImageProperties#getStorageType
- + */
- + @KeepForSdk
- + abstract Builder setStorageType(@StorageType int value);
- +
- + /** Builds the {@link ImageProperties}. */
- + @KeepForSdk
- + abstract ImageProperties build();
- + }
-
- - // Hide the constructor.
- - ImageProperties() {}
- + // Hide the constructor.
- + ImageProperties() {}
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java
- index 9365d0b2a422e..9ed88ee30c62f 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java
- @@ -19,55 +19,56 @@ import android.media.Image;
- import android.os.Build;
- import android.os.Build.VERSION;
- import android.os.Build.VERSION_CODES;
- +
- import androidx.annotation.RequiresApi;
- +
- import com.google.android.odml.image.MlImage.ImageFormat;
-
- @RequiresApi(VERSION_CODES.KITKAT)
- class MediaImageContainer implements ImageContainer {
- + private final Image mediaImage;
- + private final ImageProperties properties;
-
- - private final Image mediaImage;
- - private final ImageProperties properties;
- -
- - public MediaImageContainer(Image mediaImage) {
- - this.mediaImage = mediaImage;
- - this.properties = ImageProperties.builder()
- - .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE)
- - .setImageFormat(convertFormatCode(mediaImage.getFormat()))
- - .build();
- - }
- -
- - public Image getImage() {
- - return mediaImage;
- - }
- + public MediaImageContainer(Image mediaImage) {
- + this.mediaImage = mediaImage;
- + this.properties = ImageProperties.builder()
- + .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE)
- + .setImageFormat(convertFormatCode(mediaImage.getFormat()))
- + .build();
- + }
-
- - @Override
- - public ImageProperties getImageProperties() {
- - return properties;
- - }
- + public Image getImage() {
- + return mediaImage;
- + }
-
- - @Override
- - public void close() {
- - mediaImage.close();
- - }
- + @Override
- + public ImageProperties getImageProperties() {
- + return properties;
- + }
-
- - @ImageFormat
- - static int convertFormatCode(int graphicsFormat) {
- - // We only cover the format mentioned in
- - // https://developer.android.com/reference/android/media/Image#getFormat()
- - if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
- - if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
- - return MlImage.IMAGE_FORMAT_RGBA;
- - } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
- - return MlImage.IMAGE_FORMAT_RGB;
- - }
- + @Override
- + public void close() {
- + mediaImage.close();
- }
- - switch (graphicsFormat) {
- - case android.graphics.ImageFormat.JPEG:
- - return MlImage.IMAGE_FORMAT_JPEG;
- - case android.graphics.ImageFormat.YUV_420_888:
- - return MlImage.IMAGE_FORMAT_YUV_420_888;
- - default:
- - return MlImage.IMAGE_FORMAT_UNKNOWN;
- +
- + @ImageFormat
- + static int convertFormatCode(int graphicsFormat) {
- + // We only cover the format mentioned in
- + // https://developer.android.com/reference/android/media/Image#getFormat()
- + if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
- + if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
- + return MlImage.IMAGE_FORMAT_RGBA;
- + } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
- + return MlImage.IMAGE_FORMAT_RGB;
- + }
- + }
- + switch (graphicsFormat) {
- + case android.graphics.ImageFormat.JPEG:
- + return MlImage.IMAGE_FORMAT_JPEG;
- + case android.graphics.ImageFormat.YUV_420_888:
- + return MlImage.IMAGE_FORMAT_YUV_420_888;
- + default:
- + return MlImage.IMAGE_FORMAT_UNKNOWN;
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java
- index 73aadabb38789..59ed98b569fa2 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java
- @@ -17,6 +17,7 @@ package com.google.android.odml.image;
-
- import android.media.Image;
- import android.os.Build.VERSION_CODES;
- +
- import androidx.annotation.RequiresApi;
-
- /**
- @@ -27,26 +28,25 @@ import androidx.annotation.RequiresApi;
- */
- @RequiresApi(VERSION_CODES.KITKAT)
- public class MediaImageExtractor {
- -
- - private MediaImageExtractor() {}
- -
- - /**
- - * Extracts a {@link android.media.Image} from an {@link MlImage}. Currently it only works for
- - * {@link MlImage} that built from {@link MediaMlImageBuilder}.
- - *
- - * <p>Notice: Properties of the {@code image} like rotation will not take effects.
- - *
- - * @param image the image to extract {@link android.media.Image} from.
- - * @return {@link android.media.Image} that stored in {@link MlImage}.
- - * @throws IllegalArgumentException if the extraction failed.
- - */
- - public static Image extract(MlImage image) {
- - ImageContainer container;
- - if ((container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
- - return ((MediaImageContainer) container).getImage();
- + private MediaImageExtractor() {}
- +
- + /**
- + * Extracts a {@link android.media.Image} from an {@link MlImage}. Currently it only works for
- + * {@link MlImage} that built from {@link MediaMlImageBuilder}.
- + *
- + * <p>Notice: Properties of the {@code image} like rotation will not take effects.
- + *
- + * @param image the image to extract {@link android.media.Image} from.
- + * @return {@link android.media.Image} that stored in {@link MlImage}.
- + * @throws IllegalArgumentException if the extraction failed.
- + */
- + public static Image extract(MlImage image) {
- + ImageContainer container;
- + if ((container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
- + return ((MediaImageContainer) container).getImage();
- + }
- + throw new IllegalArgumentException(
- + "Extract Media Image from an MlImage created by objects other than Media Image"
- + + " is not supported");
- }
- - throw new IllegalArgumentException(
- - "Extract Media Image from an MlImage created by objects other than Media Image"
- - + " is not supported");
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java
- index e96ab38317bac..80771bdb91890 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java
- @@ -18,6 +18,7 @@ package com.google.android.odml.image;
- import android.graphics.Rect;
- import android.media.Image;
- import android.os.Build.VERSION_CODES;
- +
- import androidx.annotation.RequiresApi;
-
- /**
- @@ -30,65 +31,59 @@ import androidx.annotation.RequiresApi;
- */
- @RequiresApi(VERSION_CODES.KITKAT)
- public class MediaMlImageBuilder {
- + // Mandatory fields.
- + private final Image mediaImage;
-
- - // Mandatory fields.
- - private final Image mediaImage;
- -
- - // Optional fields.
- - private int rotation;
- - private Rect roi;
- - private long timestamp;
- + // Optional fields.
- + private int rotation;
- + private Rect roi;
- + private long timestamp;
-
- - /**
- - * Creates the builder with a mandatory {@link android.media.Image}.
- - *
- - * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the values
- - * will be set with default:
- - *
- - * <ul>
- - * <li>rotation: 0
- - * </ul>
- - *
- - * @param mediaImage image data object.
- - */
- - public MediaMlImageBuilder(Image mediaImage) {
- - this.mediaImage = mediaImage;
- - this.rotation = 0;
- - this.roi = new Rect(0, 0, mediaImage.getWidth(), mediaImage.getHeight());
- - this.timestamp = 0;
- - }
- + /**
- + * Creates the builder with a mandatory {@link android.media.Image}.
- + *
- + * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the
- + * values will be set with default:
- + *
- + * <ul>
- + * <li>rotation: 0
- + * </ul>
- + *
- + * @param mediaImage image data object.
- + */
- + public MediaMlImageBuilder(Image mediaImage) {
- + this.mediaImage = mediaImage;
- + this.rotation = 0;
- + this.roi = new Rect(0, 0, mediaImage.getWidth(), mediaImage.getHeight());
- + this.timestamp = 0;
- + }
-
- - /**
- - * Sets value for {@link MlImage#getRotation()}.
- - *
- - * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
- - */
- - public MediaMlImageBuilder setRotation(int rotation) {
- - MlImage.validateRotation(rotation);
- - this.rotation = rotation;
- - return this;
- - }
- + /**
- + * Sets value for {@link MlImage#getRotation()}.
- + *
- + * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
- + */
- + public MediaMlImageBuilder setRotation(int rotation) {
- + MlImage.validateRotation(rotation);
- + this.rotation = rotation;
- + return this;
- + }
-
- - /** Sets value for {@link MlImage#getRoi()}. */
- - MediaMlImageBuilder setRoi(Rect roi) {
- - this.roi = roi;
- - return this;
- - }
- + /** Sets value for {@link MlImage#getRoi()}. */
- + MediaMlImageBuilder setRoi(Rect roi) {
- + this.roi = roi;
- + return this;
- + }
-
- - /** Sets value for {@link MlImage#getTimestamp()}. */
- - MediaMlImageBuilder setTimestamp(long timestamp) {
- - this.timestamp = timestamp;
- - return this;
- - }
- + /** Sets value for {@link MlImage#getTimestamp()}. */
- + MediaMlImageBuilder setTimestamp(long timestamp) {
- + this.timestamp = timestamp;
- + return this;
- + }
-
- - /** Builds an {@link MlImage} instance. */
- - public MlImage build() {
- - return new MlImage(
- - new MediaImageContainer(mediaImage),
- - rotation,
- - roi,
- - timestamp,
- - mediaImage.getWidth(),
- - mediaImage.getHeight());
- - }
- + /** Builds an {@link MlImage} instance. */
- + public MlImage build() {
- + return new MlImage(new MediaImageContainer(mediaImage), rotation, roi, timestamp,
- + mediaImage.getWidth(), mediaImage.getHeight());
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java
- index 2ed3539de67f5..7e21e6ad428f2 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java
- @@ -16,9 +16,12 @@ limitations under the License.
- package com.google.android.odml.image;
-
- import android.graphics.Rect;
- +
- import androidx.annotation.IntDef;
- import androidx.annotation.Nullable;
- +
- import com.google.android.odml.image.annotation.KeepForSdk;
- +
- import java.io.Closeable;
- import java.lang.annotation.Retention;
- import java.lang.annotation.RetentionPolicy;
- @@ -62,228 +65,232 @@ import java.util.Map.Entry;
- * and multiple storages.
- */
- public class MlImage implements Closeable {
- + /** Specifies the image format of an image. */
- + @IntDef({
- + IMAGE_FORMAT_UNKNOWN,
- + IMAGE_FORMAT_RGBA,
- + IMAGE_FORMAT_RGB,
- + IMAGE_FORMAT_NV12,
- + IMAGE_FORMAT_NV21,
- + IMAGE_FORMAT_YV12,
- + IMAGE_FORMAT_YV21,
- + IMAGE_FORMAT_YUV_420_888,
- + IMAGE_FORMAT_ALPHA,
- + IMAGE_FORMAT_JPEG,
- + })
- + @Retention(RetentionPolicy.SOURCE)
- + public @interface ImageFormat {}
- +
- + public static final int IMAGE_FORMAT_UNKNOWN = 0;
- + public static final int IMAGE_FORMAT_RGBA = 1;
- + public static final int IMAGE_FORMAT_RGB = 2;
- + public static final int IMAGE_FORMAT_NV12 = 3;
- + public static final int IMAGE_FORMAT_NV21 = 4;
- + public static final int IMAGE_FORMAT_YV12 = 5;
- + public static final int IMAGE_FORMAT_YV21 = 6;
- + public static final int IMAGE_FORMAT_YUV_420_888 = 7;
- + public static final int IMAGE_FORMAT_ALPHA = 8;
- + public static final int IMAGE_FORMAT_JPEG = 9;
- +
- + /** Specifies the image container type. Would be useful for choosing extractors. */
- + @IntDef({
- + STORAGE_TYPE_BITMAP,
- + STORAGE_TYPE_BYTEBUFFER,
- + STORAGE_TYPE_MEDIA_IMAGE,
- + STORAGE_TYPE_IMAGE_PROXY,
- + })
- + @Retention(RetentionPolicy.SOURCE)
- + public @interface StorageType {}
- +
- + public static final int STORAGE_TYPE_BITMAP = 1;
- + public static final int STORAGE_TYPE_BYTEBUFFER = 2;
- + public static final int STORAGE_TYPE_MEDIA_IMAGE = 3;
- + public static final int STORAGE_TYPE_IMAGE_PROXY = 4;
- +
- + /**
- + * Returns a list of supported image properties for this {@link MlImage}.
- + *
- + * <p>Currently {@link MlImage} only support single storage type so the size of return list will
- + * always be 1.
- + *
- + * @see ImageProperties
- + */
- + public List<ImageProperties> getContainedImageProperties() {
- + return Collections.singletonList(getContainer().getImageProperties());
- + }
- +
- + /** Returns the rotation value attached to the image. Rotation value will be 0, 90, 180, 270. */
- + public int getRotation() {
- + return rotation;
- + }
- +
- + /** Returns the timestamp attached to the image. */
- + long getTimestamp() {
- + return timestamp;
- + }
- +
- + /** Returns the width of the image. */
- + public int getWidth() {
- + return width;
- + }
- +
- + /** Returns the height of the image. */
- + public int getHeight() {
- + return height;
- + }
-
- - /** Specifies the image format of an image. */
- - @IntDef({
- - IMAGE_FORMAT_UNKNOWN,
- - IMAGE_FORMAT_RGBA,
- - IMAGE_FORMAT_RGB,
- - IMAGE_FORMAT_NV12,
- - IMAGE_FORMAT_NV21,
- - IMAGE_FORMAT_YV12,
- - IMAGE_FORMAT_YV21,
- - IMAGE_FORMAT_YUV_420_888,
- - IMAGE_FORMAT_ALPHA,
- - IMAGE_FORMAT_JPEG,
- - })
- - @Retention(RetentionPolicy.SOURCE)
- - public @interface ImageFormat {}
- -
- - public static final int IMAGE_FORMAT_UNKNOWN = 0;
- - public static final int IMAGE_FORMAT_RGBA = 1;
- - public static final int IMAGE_FORMAT_RGB = 2;
- - public static final int IMAGE_FORMAT_NV12 = 3;
- - public static final int IMAGE_FORMAT_NV21 = 4;
- - public static final int IMAGE_FORMAT_YV12 = 5;
- - public static final int IMAGE_FORMAT_YV21 = 6;
- - public static final int IMAGE_FORMAT_YUV_420_888 = 7;
- - public static final int IMAGE_FORMAT_ALPHA = 8;
- - public static final int IMAGE_FORMAT_JPEG = 9;
- -
- - /** Specifies the image container type. Would be useful for choosing extractors. */
- - @IntDef({
- - STORAGE_TYPE_BITMAP,
- - STORAGE_TYPE_BYTEBUFFER,
- - STORAGE_TYPE_MEDIA_IMAGE,
- - STORAGE_TYPE_IMAGE_PROXY,
- - })
- - @Retention(RetentionPolicy.SOURCE)
- - public @interface StorageType {}
- -
- - public static final int STORAGE_TYPE_BITMAP = 1;
- - public static final int STORAGE_TYPE_BYTEBUFFER = 2;
- - public static final int STORAGE_TYPE_MEDIA_IMAGE = 3;
- - public static final int STORAGE_TYPE_IMAGE_PROXY = 4;
- -
- - /**
- - * Returns a list of supported image properties for this {@link MlImage}.
- - *
- - * <p>Currently {@link MlImage} only support single storage type so the size of return list will
- - * always be 1.
- - *
- - * @see ImageProperties
- - */
- - public List<ImageProperties> getContainedImageProperties() {
- - return Collections.singletonList(getContainer().getImageProperties());
- - }
- -
- - /** Returns the rotation value attached to the image. Rotation value will be 0, 90, 180, 270. */
- - public int getRotation() {
- - return rotation;
- - }
- -
- - /** Returns the timestamp attached to the image. */
- - long getTimestamp() {
- - return timestamp;
- - }
- -
- - /** Returns the width of the image. */
- - public int getWidth() {
- - return width;
- - }
- -
- - /** Returns the height of the image. */
- - public int getHeight() {
- - return height;
- - }
- -
- - /** Returns the region-of-interest rectangle attached to the image. */
- - Rect getRoi() {
- - Rect result = new Rect();
- - result.set(roi);
- - return result;
- - }
- -
- - /** Acquires a reference on this {@link MlImage}. This will increase the reference count by 1. */
- - private synchronized void acquire() {
- - referenceCount += 1;
- - }
- -
- - /**
- - * Removes a reference that was previously acquired or init.
- - *
- - * <p>When {@link MlImage} is created, it has 1 reference count.
- - *
- - * <p>When the reference count becomes 0, it will release the resource under the hood.
- - */
- - @Override
- - // TODO(b/189767728): Create an internal flag to indicate image is closed, or use referenceCount
- - public synchronized void close() {
- - referenceCount -= 1;
- - if (referenceCount == 0) {
- - for (ImageContainer imageContainer : containerMap.values()) {
- - imageContainer.close();
- - }
- + /** Returns the region-of-interest rectangle attached to the image. */
- + Rect getRoi() {
- + Rect result = new Rect();
- + result.set(roi);
- + return result;
- }
- - }
- -
- - /**
- - * Advanced API access for {@link MlImage}.
- - *
- - * <p>These APIs are useful for other infrastructures, for example, acquiring extra reference
- - * count for {@link MlImage}. However, an App developer should avoid using the following APIs.
- - *
- - * <p>APIs inside are treated as internal APIs which are subject to change.
- - */
- - public static final class Internal {
-
- /**
- * Acquires a reference on this {@link MlImage}. This will increase the reference count by 1.
- + */
- + private synchronized void acquire() {
- + referenceCount += 1;
- + }
- +
- + /**
- + * Removes a reference that was previously acquired or init.
- + *
- + * <p>When {@link MlImage} is created, it has 1 reference count.
- *
- - * <p>This method is more useful for image consumer to acquire a reference so image resource
- - * will not be closed accidentally. As image creator, normal developer doesn't need to call this
- - * method.
- + * <p>When the reference count becomes 0, it will release the resource under the hood.
- + */
- + @Override
- + // TODO(b/189767728): Create an internal flag to indicate image is closed, or use referenceCount
- + public synchronized void close() {
- + referenceCount -= 1;
- + if (referenceCount == 0) {
- + for (ImageContainer imageContainer : containerMap.values()) {
- + imageContainer.close();
- + }
- + }
- + }
- +
- + /**
- + * Advanced API access for {@link MlImage}.
- *
- - * <p>The reference count is 1 when {@link MlImage} is created. Developer can call {@link
- - * #close()} to indicate it doesn't need this {@link MlImage} anymore.
- + * <p>These APIs are useful for other infrastructures, for example, acquiring extra reference
- + * count for {@link MlImage}. However, an App developer should avoid using the following APIs.
- *
- - * @see #close()
- + * <p>APIs inside are treated as internal APIs which are subject to change.
- */
- - public void acquire() {
- - image.acquire();
- + public static final class Internal {
- + /**
- + * Acquires a reference on this {@link MlImage}. This will increase the reference count
- + * by 1.
- + *
- + * <p>This method is more useful for image consumer to acquire a reference so image resource
- + * will not be closed accidentally. As image creator, normal developer doesn't need to call
- + * this method.
- + *
- + * <p>The reference count is 1 when {@link MlImage} is created. Developer can call {@link
- + * #close()} to indicate it doesn't need this {@link MlImage} anymore.
- + *
- + * @see #close()
- + */
- + public void acquire() {
- + image.acquire();
- + }
- +
- + private final MlImage image;
- +
- + // Only MlImage creates the internal helper.
- + private Internal(MlImage image) {
- + this.image = image;
- + }
- + }
- +
- + /** Gets {@link Internal} object which contains internal APIs. */
- + public Internal getInternal() {
- + return new Internal(this);
- }
-
- - private final MlImage image;
- + private final Map<ImageProperties, ImageContainer> containerMap;
- + private final int rotation;
- + private final Rect roi;
- + private final long timestamp;
- + private final int width;
- + private final int height;
- +
- + private int referenceCount;
- +
- + /** Constructs an {@link MlImage} with a built container. */
- + @KeepForSdk
- + MlImage(ImageContainer container, int rotation, Rect roi, long timestamp, int width,
- + int height) {
- + this.containerMap = new HashMap<>();
- + containerMap.put(container.getImageProperties(), container);
- + this.rotation = rotation;
- + this.roi = new Rect();
- + this.roi.set(roi);
- + this.timestamp = timestamp;
- + this.width = width;
- + this.height = height;
- + this.referenceCount = 1;
- + }
- +
- + /**
- + * Gets one available container.
- + *
- + * @return the current container.
- + */
- + @KeepForSdk
- + ImageContainer getContainer() {
- + // According to the design, in the future we will support multiple containers in one image.
- + // Currently just return the original container.
- + // TODO(b/182443927): Cache multiple containers in MlImage.
- + return containerMap.values().iterator().next();
- + }
-
- - // Only MlImage creates the internal helper.
- - private Internal(MlImage image) {
- - this.image = image;
- + /**
- + * Gets container from required {@code storageType}. Returns {@code null} if not existed.
- + *
- + * <p>If there are multiple containers with required {@code storageType}, returns the first one.
- + */
- + @Nullable
- + @KeepForSdk
- + ImageContainer getContainer(@StorageType int storageType) {
- + for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) {
- + if (entry.getKey().getStorageType() == storageType) {
- + return entry.getValue();
- + }
- + }
- + return null;
- }
- - }
- -
- - /** Gets {@link Internal} object which contains internal APIs. */
- - public Internal getInternal() {
- - return new Internal(this);
- - }
- -
- - private final Map<ImageProperties, ImageContainer> containerMap;
- - private final int rotation;
- - private final Rect roi;
- - private final long timestamp;
- - private final int width;
- - private final int height;
- -
- - private int referenceCount;
- -
- - /** Constructs an {@link MlImage} with a built container. */
- - @KeepForSdk
- - MlImage(ImageContainer container, int rotation, Rect roi, long timestamp, int width, int height) {
- - this.containerMap = new HashMap<>();
- - containerMap.put(container.getImageProperties(), container);
- - this.rotation = rotation;
- - this.roi = new Rect();
- - this.roi.set(roi);
- - this.timestamp = timestamp;
- - this.width = width;
- - this.height = height;
- - this.referenceCount = 1;
- - }
- -
- - /**
- - * Gets one available container.
- - *
- - * @return the current container.
- - */
- - @KeepForSdk
- - ImageContainer getContainer() {
- - // According to the design, in the future we will support multiple containers in one image.
- - // Currently just return the original container.
- - // TODO(b/182443927): Cache multiple containers in MlImage.
- - return containerMap.values().iterator().next();
- - }
- -
- - /**
- - * Gets container from required {@code storageType}. Returns {@code null} if not existed.
- - *
- - * <p>If there are multiple containers with required {@code storageType}, returns the first one.
- - */
- - @Nullable
- - @KeepForSdk
- - ImageContainer getContainer(@StorageType int storageType) {
- - for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) {
- - if (entry.getKey().getStorageType() == storageType) {
- - return entry.getValue();
- - }
- +
- + /**
- + * Gets container from required {@code imageProperties}. Returns {@code null} if non existed.
- + */
- + @Nullable
- + @KeepForSdk
- + ImageContainer getContainer(ImageProperties imageProperties) {
- + return containerMap.get(imageProperties);
- }
- - return null;
- - }
- -
- - /** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */
- - @Nullable
- - @KeepForSdk
- - ImageContainer getContainer(ImageProperties imageProperties) {
- - return containerMap.get(imageProperties);
- - }
- -
- - /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
- - boolean addContainer(ImageContainer container) {
- - ImageProperties imageProperties = container.getImageProperties();
- - if (containerMap.containsKey(imageProperties)) {
- - return false;
- +
- + /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
- + boolean addContainer(ImageContainer container) {
- + ImageProperties imageProperties = container.getImageProperties();
- + if (containerMap.containsKey(imageProperties)) {
- + return false;
- + }
- + containerMap.put(imageProperties, container);
- + return true;
- }
- - containerMap.put(imageProperties, container);
- - return true;
- - }
- -
- - /**
- - * Validates rotation values for builders. Only supports 0, 90, 180, 270.
- - *
- - * @throws IllegalArgumentException if the rotation value is invalid.
- - */
- - static void validateRotation(int rotation) {
- - if (rotation != 0 && rotation != 90 && rotation != 180 && rotation != 270) {
- - throw new IllegalArgumentException(
- - "Rotation value " + rotation + " is not valid. Use only 0, 90, 180 or 270.");
- +
- + /**
- + * Validates rotation values for builders. Only supports 0, 90, 180, 270.
- + *
- + * @throws IllegalArgumentException if the rotation value is invalid.
- + */
- + static void validateRotation(int rotation) {
- + if (rotation != 0 && rotation != 90 && rotation != 180 && rotation != 270) {
- + throw new IllegalArgumentException(
- + "Rotation value " + rotation + " is not valid. Use only 0, 90, 180 or 270.");
- + }
- }
- - }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java
- index 44eb1198884fa..8408a0e424a9b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java
- @@ -16,39 +16,37 @@ limitations under the License.
- package com.google.android.odml.image;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertThrows;
-
- import android.graphics.Bitmap;
- -import java.nio.ByteBuffer;
- +
- import org.junit.Test;
- import org.junit.runner.RunWith;
- import org.robolectric.RobolectricTestRunner;
-
- +import java.nio.ByteBuffer;
- +
- /** Unit test for {@link BitmapExtractor}. */
- @RunWith(RobolectricTestRunner.class)
- public class BitmapExtractorTest {
- + @Test
- + public void extract_fromBitmap_succeeds() {
- + Bitmap bitmap = TestImageCreator.createRgbaBitmap();
- + MlImage image = new BitmapMlImageBuilder(bitmap).build();
- +
- + Bitmap result = BitmapExtractor.extract(image);
- +
- + assertThat(result).isSameInstanceAs(bitmap);
- + }
- +
- + @Test
- + public void extract_fromByteBuffer_throwsException() {
- + ByteBuffer buffer = TestImageCreator.createRgbBuffer();
- + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
- + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB)
- + .build();
-
- - @Test
- - public void extract_fromBitmap_succeeds() {
- - Bitmap bitmap = TestImageCreator.createRgbaBitmap();
- - MlImage image = new BitmapMlImageBuilder(bitmap).build();
- -
- - Bitmap result = BitmapExtractor.extract(image);
- -
- - assertThat(result).isSameInstanceAs(bitmap);
- - }
- -
- - @Test
- - public void extract_fromByteBuffer_throwsException() {
- - ByteBuffer buffer = TestImageCreator.createRgbBuffer();
- - MlImage image =
- - new ByteBufferMlImageBuilder(
- - buffer,
- - TestImageCreator.getWidth(),
- - TestImageCreator.getHeight(),
- - MlImage.IMAGE_FORMAT_RGB)
- - .build();
- -
- - assertThrows(IllegalArgumentException.class, () -> BitmapExtractor.extract(image));
- - }
- + assertThrows(IllegalArgumentException.class, () -> BitmapExtractor.extract(image));
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java
- index f9908210f2970..9a4051cdf8f6a 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java
- @@ -16,11 +16,13 @@ limitations under the License.
- package com.google.android.odml.image;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertThrows;
-
- import android.graphics.Bitmap;
- import android.graphics.Bitmap.Config;
- import android.graphics.Rect;
- +
- import org.junit.Test;
- import org.junit.runner.RunWith;
- import org.robolectric.RobolectricTestRunner;
- @@ -28,63 +30,59 @@ import org.robolectric.RobolectricTestRunner;
- /** Tests for {@link BitmapMlImageBuilder} */
- @RunWith(RobolectricTestRunner.class)
- public final class BitmapMlImageBuilderTest {
- -
- - @Test
- - public void build_fromBitmap_succeeds() {
- - Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
- -
- - MlImage image = new BitmapMlImageBuilder(bitmap).build();
- - ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP);
- -
- - assertThat(image.getWidth()).isEqualTo(20);
- - assertThat(image.getHeight()).isEqualTo(25);
- - assertThat(image.getContainedImageProperties())
- - .containsExactly(
- - ImageProperties.builder()
- - .setImageFormat(MlImage.IMAGE_FORMAT_RGBA)
- - .setStorageType(MlImage.STORAGE_TYPE_BITMAP)
- - .build());
- - assertThat(((BitmapImageContainer) container).getBitmap().getConfig())
- - .isEqualTo(Config.ARGB_8888);
- - }
- -
- - @Test
- - public void build_withOptionalProperties_succeeds() {
- - Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
- -
- - MlImage image =
- - new BitmapMlImageBuilder(bitmap)
- - .setRoi(new Rect(0, 5, 10, 15))
- - .setRotation(90)
- - .setTimestamp(12345)
- - .build();
- -
- - assertThat(image.getTimestamp()).isEqualTo(12345);
- - assertThat(image.getRotation()).isEqualTo(90);
- - assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
- - }
- -
- - @Test
- - public void build_withInvalidRotation_throwsException() {
- - Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
- - BitmapMlImageBuilder builder = new BitmapMlImageBuilder(bitmap);
- -
- - assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
- - }
- -
- - @Test
- - public void release_recyclesBitmap() {
- - Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
- -
- - MlImage image =
- - new BitmapMlImageBuilder(bitmap)
- - .setRoi(new Rect(0, 5, 10, 15))
- - .setRotation(90)
- - .setTimestamp(12345)
- - .build();
- - assertThat(bitmap.isRecycled()).isFalse();
- - image.close();
- -
- - assertThat(bitmap.isRecycled()).isTrue();
- - }
- + @Test
- + public void build_fromBitmap_succeeds() {
- + Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
- +
- + MlImage image = new BitmapMlImageBuilder(bitmap).build();
- + ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP);
- +
- + assertThat(image.getWidth()).isEqualTo(20);
- + assertThat(image.getHeight()).isEqualTo(25);
- + assertThat(image.getContainedImageProperties())
- + .containsExactly(ImageProperties.builder()
- + .setImageFormat(MlImage.IMAGE_FORMAT_RGBA)
- + .setStorageType(MlImage.STORAGE_TYPE_BITMAP)
- + .build());
- + assertThat(((BitmapImageContainer) container).getBitmap().getConfig())
- + .isEqualTo(Config.ARGB_8888);
- + }
- +
- + @Test
- + public void build_withOptionalProperties_succeeds() {
- + Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
- +
- + MlImage image = new BitmapMlImageBuilder(bitmap)
- + .setRoi(new Rect(0, 5, 10, 15))
- + .setRotation(90)
- + .setTimestamp(12345)
- + .build();
- +
- + assertThat(image.getTimestamp()).isEqualTo(12345);
- + assertThat(image.getRotation()).isEqualTo(90);
- + assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
- + }
- +
- + @Test
- + public void build_withInvalidRotation_throwsException() {
- + Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
- + BitmapMlImageBuilder builder = new BitmapMlImageBuilder(bitmap);
- +
- + assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
- + }
- +
- + @Test
- + public void release_recyclesBitmap() {
- + Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
- +
- + MlImage image = new BitmapMlImageBuilder(bitmap)
- + .setRoi(new Rect(0, 5, 10, 15))
- + .setRotation(90)
- + .setTimestamp(12345)
- + .build();
- + assertThat(bitmap.isRecycled()).isFalse();
- + image.close();
- +
- + assertThat(bitmap.isRecycled()).isTrue();
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java
- index 2ff49010443a5..e675ba9abd479 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java
- @@ -16,15 +16,18 @@ limitations under the License.
- package com.google.android.odml.image;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertThrows;
-
- import android.graphics.Bitmap;
- -import java.nio.Buffer;
- -import java.nio.ByteBuffer;
- +
- import org.junit.Test;
- import org.junit.runner.RunWith;
- import org.robolectric.RobolectricTestRunner;
-
- +import java.nio.Buffer;
- +import java.nio.ByteBuffer;
- +
- /**
- * Tests for {@link ByteBufferExtractor}.
- *
- @@ -35,145 +38,120 @@ import org.robolectric.RobolectricTestRunner;
- */
- @RunWith(RobolectricTestRunner.class)
- public final class ByteBufferExtractorTest {
- -
- - @Test
- - public void extract_fromByteBuffer_succeeds() {
- - ByteBuffer byteBuffer = TestImageCreator.createRgbBuffer();
- - MlImage image =
- - new ByteBufferMlImageBuilder(
- - byteBuffer,
- - TestImageCreator.getWidth(),
- - TestImageCreator.getHeight(),
- - MlImage.IMAGE_FORMAT_RGB)
- - .build();
- -
- - ByteBuffer result = ByteBufferExtractor.extract(image);
- -
- - assertThat(result).isEquivalentAccordingToCompareTo(byteBuffer);
- - assertThat(result.isReadOnly()).isTrue();
- - }
- -
- - @Test
- - public void extract_fromBitmap_throws() {
- - Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap();
- - MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build();
- -
- - assertThrows(IllegalArgumentException.class, () -> ByteBufferExtractor.extract(image));
- - }
- -
- - @Test
- - public void extract_rgbFromRgbByteBuffer_succeeds() {
- - ByteBuffer buffer = TestImageCreator.createRgbBuffer();
- - MlImage image =
- - new ByteBufferMlImageBuilder(
- - buffer,
- - TestImageCreator.getWidth(),
- - TestImageCreator.getHeight(),
- - MlImage.IMAGE_FORMAT_RGB)
- - .build();
- -
- - ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
- -
- - assertThat(result.isReadOnly()).isTrue();
- - assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
- - }
- -
- - @Test
- - public void extract_rgbFromRgbaByteBuffer_succeeds() {
- - ByteBuffer buffer = TestImageCreator.createRgbaBuffer();
- - MlImage image =
- - new ByteBufferMlImageBuilder(
- - buffer,
- - TestImageCreator.getWidth(),
- - TestImageCreator.getHeight(),
- - MlImage.IMAGE_FORMAT_RGBA)
- - .build();
- -
- - ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
- -
- - assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
- - assertThat(buffer.position()).isEqualTo(0);
- - }
- -
- - @Test
- - public void extract_rgbaFromRgbByteBuffer_succeeds() {
- - ByteBuffer buffer = TestImageCreator.createRgbBuffer();
- - MlImage image =
- - new ByteBufferMlImageBuilder(
- - buffer,
- - TestImageCreator.getWidth(),
- - TestImageCreator.getHeight(),
- - MlImage.IMAGE_FORMAT_RGB)
- - .build();
- -
- - ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGBA);
- -
- - assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createOpaqueRgbaBuffer());
- - assertThat(buffer.position()).isEqualTo(0);
- - }
- -
- - @Test
- - public void extract_rgbFromRgbaBitmap_succeeds() {
- - Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap();
- - MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build();
- -
- - ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
- -
- - assertThat(result.isReadOnly()).isTrue();
- - assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
- -
- - // Verifies ByteBuffer is cached inside MlImage.
- - ByteBufferImageContainer byteBufferImageContainer =
- - (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
- - assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result);
- - assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
- -
- - // Verifies that extracted ByteBuffer is the cached one.
- - ByteBuffer result2 = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
- - assertThat(result2).isEqualTo(result);
- - }
- -
- - @Test
- - public void extract_unsupportedFormatFromByteBuffer_throws() {
- - ByteBuffer buffer = TestImageCreator.createRgbaBuffer();
- - MlImage image =
- - new ByteBufferMlImageBuilder(
- - buffer,
- - TestImageCreator.getWidth(),
- - TestImageCreator.getHeight(),
- - MlImage.IMAGE_FORMAT_RGBA)
- - .build();
- -
- - assertThrows(
- - IllegalArgumentException.class,
- - () -> ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_YUV_420_888));
- - }
- -
- - @Test
- - public void extractInRecommendedFormat_anyFormatFromRgbByteBuffer_succeeds() {
- - ByteBuffer buffer = TestImageCreator.createRgbBuffer();
- - MlImage image =
- - new ByteBufferMlImageBuilder(
- - buffer,
- - TestImageCreator.getWidth(),
- - TestImageCreator.getHeight(),
- - MlImage.IMAGE_FORMAT_RGB)
- - .build();
- -
- - ByteBufferExtractor.Result result = ByteBufferExtractor.extractInRecommendedFormat(image);
- -
- - assertThat(result.buffer().isReadOnly()).isTrue();
- - assertThat(result.format()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
- -
- - // Verifies ByteBuffer is cached inside MlImage.
- - ByteBufferImageContainer byteBufferImageContainer =
- - (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
- - assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result.buffer());
- - assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
- -
- - // Verifies that extracted ByteBuffer is the cached one.
- - ByteBufferExtractor.Result result2 = ByteBufferExtractor.extractInRecommendedFormat(image);
- - assertThat(result2.buffer()).isEqualTo(result.buffer());
- - assertThat(result2.format()).isEqualTo(result.format());
- - }
- + @Test
- + public void extract_fromByteBuffer_succeeds() {
- + ByteBuffer byteBuffer = TestImageCreator.createRgbBuffer();
- + MlImage image = new ByteBufferMlImageBuilder(byteBuffer, TestImageCreator.getWidth(),
- + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB)
- + .build();
- +
- + ByteBuffer result = ByteBufferExtractor.extract(image);
- +
- + assertThat(result).isEquivalentAccordingToCompareTo(byteBuffer);
- + assertThat(result.isReadOnly()).isTrue();
- + }
- +
- + @Test
- + public void extract_fromBitmap_throws() {
- + Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap();
- + MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build();
- +
- + assertThrows(IllegalArgumentException.class, () -> ByteBufferExtractor.extract(image));
- + }
- +
- + @Test
- + public void extract_rgbFromRgbByteBuffer_succeeds() {
- + ByteBuffer buffer = TestImageCreator.createRgbBuffer();
- + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
- + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB)
- + .build();
- +
- + ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
- +
- + assertThat(result.isReadOnly()).isTrue();
- + assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
- + }
- +
- + @Test
- + public void extract_rgbFromRgbaByteBuffer_succeeds() {
- + ByteBuffer buffer = TestImageCreator.createRgbaBuffer();
- + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
- + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGBA)
- + .build();
- +
- + ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
- +
- + assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
- + assertThat(buffer.position()).isEqualTo(0);
- + }
- +
- + @Test
- + public void extract_rgbaFromRgbByteBuffer_succeeds() {
- + ByteBuffer buffer = TestImageCreator.createRgbBuffer();
- + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
- + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB)
- + .build();
- +
- + ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGBA);
- +
- + assertThat(result).isEquivalentAccordingToCompareTo(
- + TestImageCreator.createOpaqueRgbaBuffer());
- + assertThat(buffer.position()).isEqualTo(0);
- + }
- +
- + @Test
- + public void extract_rgbFromRgbaBitmap_succeeds() {
- + Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap();
- + MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build();
- +
- + ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
- +
- + assertThat(result.isReadOnly()).isTrue();
- + assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
- +
- + // Verifies ByteBuffer is cached inside MlImage.
- + ByteBufferImageContainer byteBufferImageContainer =
- + (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
- + assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result);
- + assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
- +
- + // Verifies that extracted ByteBuffer is the cached one.
- + ByteBuffer result2 = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
- + assertThat(result2).isEqualTo(result);
- + }
- +
- + @Test
- + public void extract_unsupportedFormatFromByteBuffer_throws() {
- + ByteBuffer buffer = TestImageCreator.createRgbaBuffer();
- + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
- + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGBA)
- + .build();
- +
- + assertThrows(IllegalArgumentException.class,
- + () -> ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_YUV_420_888));
- + }
- +
- + @Test
- + public void extractInRecommendedFormat_anyFormatFromRgbByteBuffer_succeeds() {
- + ByteBuffer buffer = TestImageCreator.createRgbBuffer();
- + MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
- + TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB)
- + .build();
- +
- + ByteBufferExtractor.Result result = ByteBufferExtractor.extractInRecommendedFormat(image);
- +
- + assertThat(result.buffer().isReadOnly()).isTrue();
- + assertThat(result.format()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
- +
- + // Verifies ByteBuffer is cached inside MlImage.
- + ByteBufferImageContainer byteBufferImageContainer =
- + (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
- + assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result.buffer());
- + assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
- +
- + // Verifies that extracted ByteBuffer is the cached one.
- + ByteBufferExtractor.Result result2 = ByteBufferExtractor.extractInRecommendedFormat(image);
- + assertThat(result2.buffer()).isEqualTo(result.buffer());
- + assertThat(result2.format()).isEqualTo(result.format());
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java
- index 45ba77934a61f..374c82b3f4e8d 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java
- @@ -16,61 +16,62 @@ limitations under the License.
- package com.google.android.odml.image;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertThrows;
-
- import android.graphics.Rect;
- -import java.nio.ByteBuffer;
- +
- import org.junit.Test;
- import org.junit.runner.RunWith;
- import org.robolectric.RobolectricTestRunner;
-
- +import java.nio.ByteBuffer;
- +
- /** Tests for {@link ByteBufferMlImageBuilder} */
- @RunWith(RobolectricTestRunner.class)
- public final class ByteBufferMlImageBuilderTest {
- + @Test
- + public void build_fromByteBuffer_succeeds() {
- + ByteBuffer buffer = ByteBuffer.allocate(500);
- +
- + MlImage image =
- + new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB).build();
- + ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
- +
- + assertThat(image.getWidth()).isEqualTo(20);
- + assertThat(image.getHeight()).isEqualTo(25);
- + assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, 20, 25));
- + assertThat(image.getRotation()).isEqualTo(0);
- + assertThat(image.getContainedImageProperties())
- + .containsExactly(ImageProperties.builder()
- + .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
- + .setImageFormat(MlImage.IMAGE_FORMAT_RGB)
- + .build());
- + assertThat(((ByteBufferImageContainer) container).getImageFormat())
- + .isEqualTo(MlImage.IMAGE_FORMAT_RGB);
- + }
- +
- + @Test
- + public void build_withOptionalProperties_succeeds() {
- + ByteBuffer buffer = ByteBuffer.allocate(500);
- +
- + MlImage image = new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB)
- + .setRoi(new Rect(0, 5, 10, 15))
- + .setRotation(90)
- + .setTimestamp(12345)
- + .build();
- +
- + assertThat(image.getTimestamp()).isEqualTo(12345);
- + assertThat(image.getRotation()).isEqualTo(90);
- + assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
- + }
- +
- + @Test
- + public void build_withInvalidRotation_throwsException() {
- + ByteBuffer buffer = ByteBuffer.allocate(500);
- + ByteBufferMlImageBuilder builder =
- + new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB);
-
- - @Test
- - public void build_fromByteBuffer_succeeds() {
- - ByteBuffer buffer = ByteBuffer.allocate(500);
- -
- - MlImage image = new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB).build();
- - ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
- -
- - assertThat(image.getWidth()).isEqualTo(20);
- - assertThat(image.getHeight()).isEqualTo(25);
- - assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, 20, 25));
- - assertThat(image.getRotation()).isEqualTo(0);
- - assertThat(image.getContainedImageProperties())
- - .containsExactly(
- - ImageProperties.builder()
- - .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
- - .setImageFormat(MlImage.IMAGE_FORMAT_RGB)
- - .build());
- - assertThat(((ByteBufferImageContainer) container).getImageFormat())
- - .isEqualTo(MlImage.IMAGE_FORMAT_RGB);
- - }
- -
- - @Test
- - public void build_withOptionalProperties_succeeds() {
- - ByteBuffer buffer = ByteBuffer.allocate(500);
- -
- - MlImage image =
- - new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB)
- - .setRoi(new Rect(0, 5, 10, 15))
- - .setRotation(90)
- - .setTimestamp(12345)
- - .build();
- -
- - assertThat(image.getTimestamp()).isEqualTo(12345);
- - assertThat(image.getRotation()).isEqualTo(90);
- - assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
- - }
- -
- - @Test
- - public void build_withInvalidRotation_throwsException() {
- - ByteBuffer buffer = ByteBuffer.allocate(500);
- - ByteBufferMlImageBuilder builder =
- - new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB);
- -
- - assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
- - }
- + assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java
- index 67ed4a7f6e2c4..fa832671e4458 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java
- @@ -16,6 +16,7 @@ limitations under the License.
- package com.google.android.odml.image;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertThrows;
- import static org.mockito.Mockito.when;
-
- @@ -23,6 +24,7 @@ import android.graphics.Bitmap;
- import android.graphics.Bitmap.Config;
- import android.graphics.ImageFormat;
- import android.media.Image;
- +
- import org.junit.Before;
- import org.junit.Test;
- import org.junit.runner.RunWith;
- @@ -33,34 +35,34 @@ import org.robolectric.RobolectricTestRunner;
- /** Tests for {@link MediaImageExtractor} */
- @RunWith(RobolectricTestRunner.class)
- public final class MediaImageExtractorTest {
- - private static final int HEIGHT = 100;
- - private static final int WIDTH = 50;
- + private static final int HEIGHT = 100;
- + private static final int WIDTH = 50;
-
- - @Mock private Image mediaImage;
- + @Mock
- + private Image mediaImage;
-
- - @Before
- - public void setUp() {
- - MockitoAnnotations.initMocks(this);
- + @Before
- + public void setUp() {
- + MockitoAnnotations.initMocks(this);
-
- - when(mediaImage.getHeight()).thenReturn(HEIGHT);
- - when(mediaImage.getWidth()).thenReturn(WIDTH);
- - when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888);
- - }
- + when(mediaImage.getHeight()).thenReturn(HEIGHT);
- + when(mediaImage.getWidth()).thenReturn(WIDTH);
- + when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888);
- + }
-
- - @Test
- - public void extract_fromMediaMlImage_succeeds() {
- - MlImage image = new MediaMlImageBuilder(mediaImage).build();
- - Image extractedMediaImage = MediaImageExtractor.extract(image);
- + @Test
- + public void extract_fromMediaMlImage_succeeds() {
- + MlImage image = new MediaMlImageBuilder(mediaImage).build();
- + Image extractedMediaImage = MediaImageExtractor.extract(image);
-
- - assertThat(extractedMediaImage).isSameInstanceAs(image);
- - }
- + assertThat(extractedMediaImage).isSameInstanceAs(image);
- + }
-
- - @Test
- - public void extract_fromBitmapMlImage_throwsException() {
- - MlImage image =
- - new BitmapMlImageBuilder(
- + @Test
- + public void extract_fromBitmapMlImage_throwsException() {
- + MlImage image = new BitmapMlImageBuilder(
- Bitmap.createBitmap(/* width= */ 20, /* height= */ 25, Config.ARGB_8888))
- - .build();
- - assertThrows(IllegalArgumentException.class, () -> MediaImageExtractor.extract(image));
- - }
- + .build();
- + assertThrows(IllegalArgumentException.class, () -> MediaImageExtractor.extract(image));
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java
- index 4f589874bfaf8..60397feceb067 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java
- @@ -16,12 +16,14 @@ limitations under the License.
- package com.google.android.odml.image;
-
- import static com.google.common.truth.Truth.assertThat;
- +
- import static org.junit.Assert.assertThrows;
- import static org.mockito.Mockito.when;
-
- import android.graphics.ImageFormat;
- import android.graphics.Rect;
- import android.media.Image;
- +
- import org.junit.Before;
- import org.junit.Test;
- import org.junit.runner.RunWith;
- @@ -32,58 +34,57 @@ import org.robolectric.RobolectricTestRunner;
- /** Tests for {@link MediaMlImageBuilder} */
- @RunWith(RobolectricTestRunner.class)
- public final class MediaMlImageBuilderTest {
- - private static final int HEIGHT = 100;
- - private static final int WIDTH = 50;
- -
- - @Mock private Image mediaImage;
- -
- - @Before
- - public void setUp() {
- - MockitoAnnotations.initMocks(this);
- -
- - when(mediaImage.getHeight()).thenReturn(HEIGHT);
- - when(mediaImage.getWidth()).thenReturn(WIDTH);
- - when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888);
- - }
- -
- - @Test
- - public void build_fromMediaImage_succeeds() {
- - MlImage image = new MediaMlImageBuilder(mediaImage).build();
- - ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE);
- -
- - assertThat(image.getWidth()).isEqualTo(WIDTH);
- - assertThat(image.getHeight()).isEqualTo(HEIGHT);
- - assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, WIDTH, HEIGHT));
- - assertThat(image.getRotation()).isEqualTo(0);
- - assertThat(image.getTimestamp()).isAtLeast(0);
- - assertThat(image.getContainedImageProperties())
- - .containsExactly(
- - ImageProperties.builder()
- - .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE)
- - .setImageFormat(MlImage.IMAGE_FORMAT_YUV_420_888)
- - .build());
- - assertThat(((MediaImageContainer) container).getImage().getFormat())
- - .isEqualTo(ImageFormat.YUV_420_888);
- - }
- -
- - @Test
- - public void build_withOptionalProperties_succeeds() {
- - MlImage image =
- - new MediaMlImageBuilder(mediaImage)
- - .setTimestamp(12345)
- - .setRoi(new Rect(0, 5, 10, 15))
- - .setRotation(90)
- - .build();
- -
- - assertThat(image.getTimestamp()).isEqualTo(12345);
- - assertThat(image.getRotation()).isEqualTo(90);
- - assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
- - }
- -
- - @Test
- - public void build_withInvalidRotation_throwsException() {
- - MediaMlImageBuilder builder = new MediaMlImageBuilder(mediaImage);
- -
- - assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
- - }
- + private static final int HEIGHT = 100;
- + private static final int WIDTH = 50;
- +
- + @Mock
- + private Image mediaImage;
- +
- + @Before
- + public void setUp() {
- + MockitoAnnotations.initMocks(this);
- +
- + when(mediaImage.getHeight()).thenReturn(HEIGHT);
- + when(mediaImage.getWidth()).thenReturn(WIDTH);
- + when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888);
- + }
- +
- + @Test
- + public void build_fromMediaImage_succeeds() {
- + MlImage image = new MediaMlImageBuilder(mediaImage).build();
- + ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE);
- +
- + assertThat(image.getWidth()).isEqualTo(WIDTH);
- + assertThat(image.getHeight()).isEqualTo(HEIGHT);
- + assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, WIDTH, HEIGHT));
- + assertThat(image.getRotation()).isEqualTo(0);
- + assertThat(image.getTimestamp()).isAtLeast(0);
- + assertThat(image.getContainedImageProperties())
- + .containsExactly(ImageProperties.builder()
- + .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE)
- + .setImageFormat(MlImage.IMAGE_FORMAT_YUV_420_888)
- + .build());
- + assertThat(((MediaImageContainer) container).getImage().getFormat())
- + .isEqualTo(ImageFormat.YUV_420_888);
- + }
- +
- + @Test
- + public void build_withOptionalProperties_succeeds() {
- + MlImage image = new MediaMlImageBuilder(mediaImage)
- + .setTimestamp(12345)
- + .setRoi(new Rect(0, 5, 10, 15))
- + .setRotation(90)
- + .build();
- +
- + assertThat(image.getTimestamp()).isEqualTo(12345);
- + assertThat(image.getRotation()).isEqualTo(90);
- + assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
- + }
- +
- + @Test
- + public void build_withInvalidRotation_throwsException() {
- + MediaMlImageBuilder builder = new MediaMlImageBuilder(mediaImage);
- +
- + assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
- + }
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java
- index c9e7134bedd93..28f54be2c70a3 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java
- @@ -17,6 +17,7 @@ package com.google.android.odml.image;
-
- import android.graphics.Bitmap;
- import android.graphics.Color;
- +
- import java.nio.ByteBuffer;
-
- /**
- @@ -35,113 +36,113 @@ import java.nio.ByteBuffer;
- * <p>The created {@link Bitmap} is not pre-multiplied.
- */
- final class TestImageCreator {
- + private static final int RED = 0x73;
- + private static final int GREEN = 0x85;
- + private static final int BLUE = 0x96;
- + private static final int ALPHA = 0x70;
- +
- + static int getWidth() {
- + return 10;
- + }
- +
- + static int getHeight() {
- + return 2;
- + }
- +
- + /**
- + * Creates an example non-pre-multiplied bitmap which is 100% opaque.
- + *
- + * @see TestImageCreator for details.
- + */
- + static Bitmap createOpaqueRgbaBitmap() {
- + return createRgbaBitmap(0xff);
- + }
- +
- + /**
- + * Creates an example non-pre-multiplied bitmap which has non-trivial alpha channel.
- + *
- + * @see TestImageCreator for details.
- + */
- + static Bitmap createRgbaBitmap() {
- + return createRgbaBitmap(ALPHA);
- + }
-
- - private static final int RED = 0x73;
- - private static final int GREEN = 0x85;
- - private static final int BLUE = 0x96;
- - private static final int ALPHA = 0x70;
- -
- - static int getWidth() {
- - return 10;
- - }
- -
- - static int getHeight() {
- - return 2;
- - }
- -
- - /**
- - * Creates an example non-pre-multiplied bitmap which is 100% opaque.
- - *
- - * @see TestImageCreator for details.
- - */
- - static Bitmap createOpaqueRgbaBitmap() {
- - return createRgbaBitmap(0xff);
- - }
- -
- - /**
- - * Creates an example non-pre-multiplied bitmap which has non-trivial alpha channel.
- - *
- - * @see TestImageCreator for details.
- - */
- - static Bitmap createRgbaBitmap() {
- - return createRgbaBitmap(ALPHA);
- - }
- -
- - /**
- - * Creates an example 10x2 bitmap demonstrated in the class doc. A channel sets to {@code alpha}.
- - */
- - static Bitmap createRgbaBitmap(int alpha) {
- - int[] colors = new int[20];
- - for (int i = 0; i < 5; i++) {
- - colors[i] = Color.argb(alpha, 0, 0, BLUE);
- - colors[i + 5] = Color.argb(alpha, 0xff, 0xff, 0xff);
- - colors[i + 10] = Color.argb(alpha, 0, GREEN, 0);
- - colors[i + 15] = Color.argb(alpha, RED, 0, 0);
- + /**
- + * Creates an example 10x2 bitmap demonstrated in the class doc. A channel sets to {@code
- + * alpha}.
- + */
- + static Bitmap createRgbaBitmap(int alpha) {
- + int[] colors = new int[20];
- + for (int i = 0; i < 5; i++) {
- + colors[i] = Color.argb(alpha, 0, 0, BLUE);
- + colors[i + 5] = Color.argb(alpha, 0xff, 0xff, 0xff);
- + colors[i + 10] = Color.argb(alpha, 0, GREEN, 0);
- + colors[i + 15] = Color.argb(alpha, RED, 0, 0);
- + }
- + // We don't use Bitmap#createBitmap(int[] ...) here, because that method creates
- + // pre-multiplied bitmaps.
- + Bitmap bitmap = Bitmap.createBitmap(10, 2, Bitmap.Config.ARGB_8888);
- + bitmap.setPremultiplied(false);
- + bitmap.setPixels(colors, 0, 10, 0, 0, 10, 2);
- + return bitmap;
- }
- - // We don't use Bitmap#createBitmap(int[] ...) here, because that method creates pre-multiplied
- - // bitmaps.
- - Bitmap bitmap = Bitmap.createBitmap(10, 2, Bitmap.Config.ARGB_8888);
- - bitmap.setPremultiplied(false);
- - bitmap.setPixels(colors, 0, 10, 0, 0, 10, 2);
- - return bitmap;
- - }
- -
- - /**
- - * Creates an example 10*10*3 bytebuffer in R-G-B format.
- - *
- - * @see TestImageCreator for details.
- - */
- - static ByteBuffer createRgbBuffer() {
- - return createRgbOrRgbaBuffer(false, 0xff);
- - }
- -
- - /**
- - * Creates an example 10*10*4 bytebuffer in R-G-B-A format.
- - *
- - * @see TestImageCreator for details.
- - */
- - static ByteBuffer createRgbaBuffer() {
- - return createRgbOrRgbaBuffer(true, ALPHA);
- - }
- -
- - /**
- - * Creates an example 10*10*4 bytebuffer in R-G-B-A format, but the A channel is 0xFF.
- - *
- - * @see TestImageCreator for details.
- - */
- - static ByteBuffer createOpaqueRgbaBuffer() {
- - return createRgbOrRgbaBuffer(true, 0xff);
- - }
- -
- - /**
- - * Creates an example 10x2x4 (or 10x2x3 if no alpha) bytebuffer demonstrated in the class doc.
- - *
- - * @param withAlpha if true, set A to {@code alpha}, otherwise A channel is ignored.
- - * @param alpha alpha channel value. Only effective when {@code withAlpha} is {@code true}.
- - */
- - static ByteBuffer createRgbOrRgbaBuffer(boolean withAlpha, int alpha) {
- - int capacity = withAlpha ? 80 : 60;
- - ByteBuffer buffer = ByteBuffer.allocateDirect(capacity);
- - putColorInByteBuffer(buffer, 0, 0, BLUE, withAlpha, alpha, 5);
- - putColorInByteBuffer(buffer, 0xff, 0xff, 0xff, withAlpha, alpha, 5);
- - putColorInByteBuffer(buffer, 0, GREEN, 0, withAlpha, alpha, 5);
- - putColorInByteBuffer(buffer, RED, 0, 0, withAlpha, alpha, 5);
- - buffer.rewind();
- - return buffer;
- - }
- -
- - private static void putColorInByteBuffer(
- - ByteBuffer buffer, int r, int g, int b, boolean withAlpha, int alpha, int num) {
- - for (int i = 0; i < num; i++) {
- - buffer.put((byte) r);
- - buffer.put((byte) g);
- - buffer.put((byte) b);
- - if (withAlpha) {
- - buffer.put((byte) alpha);
- - }
- +
- + /**
- + * Creates an example 10*10*3 bytebuffer in R-G-B format.
- + *
- + * @see TestImageCreator for details.
- + */
- + static ByteBuffer createRgbBuffer() {
- + return createRgbOrRgbaBuffer(false, 0xff);
- + }
- +
- + /**
- + * Creates an example 10*10*4 bytebuffer in R-G-B-A format.
- + *
- + * @see TestImageCreator for details.
- + */
- + static ByteBuffer createRgbaBuffer() {
- + return createRgbOrRgbaBuffer(true, ALPHA);
- + }
- +
- + /**
- + * Creates an example 10*10*4 bytebuffer in R-G-B-A format, but the A channel is 0xFF.
- + *
- + * @see TestImageCreator for details.
- + */
- + static ByteBuffer createOpaqueRgbaBuffer() {
- + return createRgbOrRgbaBuffer(true, 0xff);
- + }
- +
- + /**
- + * Creates an example 10x2x4 (or 10x2x3 if no alpha) bytebuffer demonstrated in the class doc.
- + *
- + * @param withAlpha if true, set A to {@code alpha}, otherwise A channel is ignored.
- + * @param alpha alpha channel value. Only effective when {@code withAlpha} is {@code true}.
- + */
- + static ByteBuffer createRgbOrRgbaBuffer(boolean withAlpha, int alpha) {
- + int capacity = withAlpha ? 80 : 60;
- + ByteBuffer buffer = ByteBuffer.allocateDirect(capacity);
- + putColorInByteBuffer(buffer, 0, 0, BLUE, withAlpha, alpha, 5);
- + putColorInByteBuffer(buffer, 0xff, 0xff, 0xff, withAlpha, alpha, 5);
- + putColorInByteBuffer(buffer, 0, GREEN, 0, withAlpha, alpha, 5);
- + putColorInByteBuffer(buffer, RED, 0, 0, withAlpha, alpha, 5);
- + buffer.rewind();
- + return buffer;
- + }
- +
- + private static void putColorInByteBuffer(
- + ByteBuffer buffer, int r, int g, int b, boolean withAlpha, int alpha, int num) {
- + for (int i = 0; i < num; i++) {
- + buffer.put((byte) r);
- + buffer.put((byte) g);
- + buffer.put((byte) b);
- + if (withAlpha) {
- + buffer.put((byte) alpha);
- + }
- + }
- }
- - }
-
- - // Should not be instantiated.
- - private TestImageCreator() {}
- + // Should not be instantiated.
- + private TestImageCreator() {}
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/pybinds/_pywrap_audio_buffer.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/pybinds/_pywrap_audio_buffer.cc
- index b46a997c4e254..c5e317d8a82c0 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/pybinds/_pywrap_audio_buffer.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/pybinds/_pywrap_audio_buffer.cc
- @@ -39,16 +39,15 @@ PYBIND11_MODULE(_pywrap_audio_buffer, m) {
- .def_readonly("sample_rate", &AudioBuffer::AudioFormat::sample_rate);
-
- py::class_<AudioBuffer>(m, "AudioBuffer", py::buffer_protocol())
- - .def(py::init([](
- - py::buffer buffer, const int sample_count,
- - const AudioBuffer::AudioFormat& audio_format)
- - -> std::unique_ptr<AudioBuffer> {
- - py::buffer_info info = buffer.request();
- + .def(py::init([](py::buffer buffer, const int sample_count,
- + const AudioBuffer::AudioFormat& audio_format)
- + -> std::unique_ptr<AudioBuffer> {
- + py::buffer_info info = buffer.request();
-
- - auto audio_buffer = AudioBuffer::Create(
- - static_cast<float*>(info.ptr), sample_count, audio_format);
- - return core::get_value(audio_buffer);
- - }))
- + auto audio_buffer = AudioBuffer::Create(static_cast<float*>(info.ptr),
- + sample_count, audio_format);
- + return core::get_value(audio_buffer);
- + }))
- .def_property_readonly("audio_format", &AudioBuffer::GetAudioFormat)
- .def_property_readonly("buffer_size", &AudioBuffer::GetBufferSize)
- .def_property_readonly("float_buffer", [](AudioBuffer& self) {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc
- index 5d94db2a01b37..e2054cf645c08 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc
- @@ -20,7 +20,6 @@ limitations under the License.
- #include "tensorflow_lite_support/cc/task/audio/proto/classifications_proto_inc.h"
- #include "tensorflow_lite_support/cc/task/processor/proto/classification_options.pb.h"
- #include "tensorflow_lite_support/cc/task/processor/proto/classifications.pb.h"
- -#include "tensorflow_lite_support/cc/task/processor/proto/classifications.pb.h"
- #include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h"
-
- namespace tflite {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc
- index 50e0b4f7ce4a8..8b1d67d9f8e05 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc
- @@ -15,9 +15,9 @@ limitations under the License.
-
- #include "pybind11/pybind11.h"
- #include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf
- -#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h"
- #include "tensorflow_lite_support/cc/task/audio/audio_embedder.h"
- #include "tensorflow_lite_support/cc/task/audio/core/audio_buffer.h"
- +#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h"
- #include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h"
-
- namespace tflite {
- @@ -50,17 +50,17 @@ PYBIND11_MODULE(_pywrap_audio_embedder, m) {
- return core::get_value(embedder);
- })
- .def_static("cosine_similarity",
- - [](const processor::FeatureVector& u,
- - const processor::FeatureVector& v) -> double {
- - auto similarity = AudioEmbedder::CosineSimilarity(u, v);
- - return core::get_value(similarity);
- - })
- + [](const processor::FeatureVector& u,
- + const processor::FeatureVector& v) -> double {
- + auto similarity = AudioEmbedder::CosineSimilarity(u, v);
- + return core::get_value(similarity);
- + })
- .def("embed",
- - [](AudioEmbedder& self,
- - const AudioBuffer& audio_buffer) -> processor::EmbeddingResult {
- - auto embedding_result = self.Embed(audio_buffer);
- - return core::get_value(embedding_result);
- - })
- + [](AudioEmbedder& self,
- + const AudioBuffer& audio_buffer) -> processor::EmbeddingResult {
- + auto embedding_result = self.Embed(audio_buffer);
- + return core::get_value(embedding_result);
- + })
- .def("get_embedding_dimension", &AudioEmbedder::GetEmbeddingDimension)
- .def("get_number_of_output_layers",
- &AudioEmbedder::GetNumberOfOutputLayers)
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc
- index 977b4e16175ac..124f5cb1ad15d 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc
- @@ -43,13 +43,13 @@ PYBIND11_MODULE(image_utils, m) {
- int width = info.shape[1];
- int channels = info.ndim == 3 ? info.shape[2] : 1;
-
- - return ImageData{static_cast<uint8 *>(info.ptr), width, height,
- + return ImageData{static_cast<uint8*>(info.ptr), width, height,
- channels};
- }))
- .def_readonly("width", &ImageData::width)
- .def_readonly("height", &ImageData::height)
- .def_readonly("channels", &ImageData::channels)
- - .def_buffer([](ImageData &data) -> py::buffer_info {
- + .def_buffer([](ImageData& data) -> py::buffer_info {
- return py::buffer_info(
- data.pixel_data, sizeof(uint8),
- py::format_descriptor<uint8>::format(), 3,
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc
- index 4ca20a363345e..b4f23baa6e0b1 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc
- @@ -67,17 +67,17 @@ PYBIND11_MODULE(_pywrap_image_classifier, m) {
- return core::get_value(classifier);
- })
- .def("classify",
- - [](ImageClassifier& self, const ImageData& image_data)
- - -> processor::ClassificationResult {
- + [](ImageClassifier& self,
- + const ImageData& image_data) -> processor::ClassificationResult {
- auto frame_buffer = CreateFrameBufferFromImageData(image_data);
- - auto vision_classification_result = self.Classify(
- - *core::get_value(frame_buffer));
- + auto vision_classification_result =
- + self.Classify(*core::get_value(frame_buffer));
- // Convert from vision::ClassificationResult to
- // processor::ClassificationResult as required by the Python layer.
- processor::ClassificationResult classification_result;
- - classification_result.ParseFromString(
- + classification_result.ParseFromString(
- core::get_value(vision_classification_result)
- - .SerializeAsString());
- + .SerializeAsString());
- return classification_result;
- })
- .def("classify",
- @@ -96,9 +96,9 @@ PYBIND11_MODULE(_pywrap_image_classifier, m) {
- // Convert from vision::ClassificationResult to
- // processor::ClassificationResult as required by the Python layer.
- processor::ClassificationResult classification_result;
- - classification_result.ParseFromString(
- + classification_result.ParseFromString(
- core::get_value(vision_classification_result)
- - .SerializeAsString());
- + .SerializeAsString());
- return classification_result;
- });
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc
- index 3ebf09fb4f284..e71048e9ebb0b 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc
- @@ -47,23 +47,23 @@ PYBIND11_MODULE(_pywrap_image_segmenter, m) {
-
- if (segmentation_options.has_display_names_locale()) {
- options.set_display_names_locale(
- - segmentation_options.display_names_locale());
- + segmentation_options.display_names_locale());
- }
- if (segmentation_options.has_output_type()) {
- options.set_output_type(
- static_cast<ImageSegmenterOptions::OutputType>(
- - segmentation_options.output_type()));
- + segmentation_options.output_type()));
- }
-
- auto segmenter = ImageSegmenter::CreateFromOptions(options);
- return core::get_value(segmenter);
- })
- .def("segment",
- - [](ImageSegmenter& self, const ImageData& image_data)
- - -> SegmentationResult {
- + [](ImageSegmenter& self,
- + const ImageData& image_data) -> SegmentationResult {
- auto frame_buffer = CreateFrameBufferFromImageData(image_data);
- - auto vision_segmentation_result = self.Segment(
- - *core::get_value(frame_buffer));
- + auto vision_segmentation_result =
- + self.Segment(*core::get_value(frame_buffer));
- return core::get_value(vision_segmentation_result);
- });
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc
- index 39e39c9df00e1..3749efc811019 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc
- @@ -65,17 +65,16 @@ PYBIND11_MODULE(_pywrap_object_detector, m) {
- return core::get_value(detector);
- })
- .def("detect",
- - [](ObjectDetector& self, const ImageData& image_data)
- - -> processor::DetectionResult {
- + [](ObjectDetector& self,
- + const ImageData& image_data) -> processor::DetectionResult {
- auto frame_buffer = CreateFrameBufferFromImageData(image_data);
- - auto vision_detection_result = self.Detect(
- - *core::get_value(frame_buffer));
- + auto vision_detection_result =
- + self.Detect(*core::get_value(frame_buffer));
- // Convert from vision::DetectionResult to
- // processor::DetectionResult as required by the Python layer.
- processor::DetectionResult detection_result;
- - detection_result.ParseFromString(
- - core::get_value(vision_detection_result)
- - .SerializeAsString());
- + detection_result.ParseFromString(
- + core::get_value(vision_detection_result).SerializeAsString());
- return detection_result;
- });
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h
- index 89c96e7d5e50a..67e0e303d4231 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h
- @@ -29,7 +29,9 @@ namespace scann_ondevice {
- namespace core {
-
- template <typename LutType>
- -void RearrangeLUT(const LutType* input_data, int batch_elems, int batch_size,
- +void RearrangeLUT(const LutType* input_data,
- + int batch_elems,
- + int batch_size,
- LutType* const output_data) {
- std::vector<int64_t> simd_sizes;
- if (std::is_same<LutType, float>::value) {
- @@ -88,10 +90,15 @@ struct MaxQuantizationValue<uint16_t> {
- };
-
- template <typename SimdType, typename LutType, size_t NumCenters = 0>
- -size_t IndexTableSumSimdBatch(const uint8_t* indices, size_t num_chunks,
- - size_t num_outputs, const LutType* lookup_table,
- - size_t batch_size, size_t num_centers, float min,
- - float max, size_t batch_index,
- +size_t IndexTableSumSimdBatch(const uint8_t* indices,
- + size_t num_chunks,
- + size_t num_outputs,
- + const LutType* lookup_table,
- + size_t batch_size,
- + size_t num_centers,
- + float min,
- + float max,
- + size_t batch_index,
- float* const output) {
- if (num_centers == 256) {
- return IndexTableSumSimdBatch<SimdType, LutType, 256>(
- @@ -176,9 +183,14 @@ size_t IndexTableSumSimdBatch(const uint8_t* indices, size_t num_chunks,
- }
-
- template <typename LutType>
- -void IndexTableSum(const uint8_t* indices, size_t num_chunks,
- - size_t num_outputs, const LutType* lookup_table,
- - size_t batch_size, size_t num_centers, float min, float max,
- +void IndexTableSum(const uint8_t* indices,
- + size_t num_chunks,
- + size_t num_outputs,
- + const LutType* lookup_table,
- + size_t batch_size,
- + size_t num_centers,
- + float min,
- + float max,
- float* const output) {
- static_assert(std::is_same<LutType, uint8_t>::value ||
- std::is_same<LutType, uint16_t>::value,
- @@ -206,10 +218,15 @@ void IndexTableSum(const uint8_t* indices, size_t num_chunks,
- }
-
- template <>
- -inline void IndexTableSum<float>(const uint8_t* indices, size_t num_chunks,
- - size_t num_outputs, const float* lookup_table,
- - size_t batch_size, size_t num_centers,
- - float min, float max, float* const output) {
- +inline void IndexTableSum<float>(const uint8_t* indices,
- + size_t num_chunks,
- + size_t num_outputs,
- + const float* lookup_table,
- + size_t batch_size,
- + size_t num_centers,
- + float min,
- + float max,
- + float* const output) {
- std::fill(output, output + batch_size * num_outputs, 0.0f);
- size_t i = 0;
- #ifdef __AVX__
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer.cc
- index ed17d7f1708f8..6df064553d2c5 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer.cc
- @@ -31,7 +31,8 @@ namespace core {
- namespace {
-
- float ComputeSquaredL2Distance(Span<const float> a, Span<const float> b) {
- - if (a.size() != b.size()) return 0;
- + if (a.size() != b.size())
- + return 0;
- float result = 0;
- for (int i = 0; i < a.size(); ++i) {
- result += (a[i] - b[i]) * (a[i] - b[i]);
- @@ -40,7 +41,8 @@ float ComputeSquaredL2Distance(Span<const float> a, Span<const float> b) {
- }
-
- float ComputeDotProductDistance(Span<const float> a, Span<const float> b) {
- - if (a.size() != b.size()) return 0;
- + if (a.size() != b.size())
- + return 0;
- float result = 0;
- for (int i = 0; i < a.size(); ++i) {
- result += a[i] * b[i];
- @@ -62,7 +64,8 @@ AsymmetricHashingIndexer::AsymmetricHashingIndexer(
- int subspace_index = 0;
- for (const AsymmetricHashingProto::SubspaceCodebook& codebook :
- ah_proto.subspace()) {
- - if (codebook.entry().empty()) return;
- + if (codebook.entry().empty())
- + return;
-
- const int dimension = codebook.entry(0).dimension_size();
- const int num_codes = codebook.entry_size();
- @@ -81,13 +84,17 @@ AsymmetricHashingIndexer::AsymmetricHashingIndexer(
- }
-
- total_dimension_ = 0;
- - for (const uint8_t dim : dimensions_) total_dimension_ += dim;
- + for (const uint8_t dim : dimensions_)
- + total_dimension_ += dim;
- }
-
- void AsymmetricHashingIndexer::EncodeDatapoint(
- - absl::Span<const float> original, absl::Span<uint8_t> encoded) const {
- - if (original.size() != total_dimension_) return;
- - if (encoded.size() != dimensions_.size()) return;
- + absl::Span<const float> original,
- + absl::Span<uint8_t> encoded) const {
- + if (original.size() != total_dimension_)
- + return;
- + if (encoded.size() != dimensions_.size())
- + return;
-
- int start_index = 0;
- for (int i = 0; i < dimensions_.size(); ++i) {
- @@ -118,7 +125,8 @@ void AsymmetricHashingIndexer::EncodeDatapoint(
- }
-
- absl::Status AsymmetricHashingIndexer::DecodeDatapoint(
- - absl::Span<const uint8_t> encoded, absl::Span<float> reconstructed) const {
- + absl::Span<const uint8_t> encoded,
- + absl::Span<float> reconstructed) const {
- if (encoded.size() < dimensions_.size()) {
- return absl::InvalidArgumentError("Mismatching dimensions");
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer.h
- index 0328a75837ba9..a0515667e8373 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer.h
- @@ -21,7 +21,7 @@ limitations under the License.
- #include <string>
-
- #include "absl/status/status.h" // from @com_google_absl
- -#include "absl/types/span.h" // from @com_google_absl
- +#include "absl/types/span.h" // from @com_google_absl
- #include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
-
- namespace tflite {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer_test.cc
- index 3ef7427d6e21a..fca3d8f3d21c9 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/indexer_test.cc
- @@ -92,7 +92,8 @@ TEST(IndexerTest, SquaredL2AsymmetricReconstruct1) {
- indexer.EncodeDatapoint(datapoint, absl::MakeSpan(result));
-
- vector<float> datapoint_recon(5, 0);
- - SUPPORT_EXPECT_OK(indexer.DecodeDatapoint(result, absl::MakeSpan(datapoint_recon)));
- + SUPPORT_EXPECT_OK(
- + indexer.DecodeDatapoint(result, absl::MakeSpan(datapoint_recon)));
-
- EXPECT_EQ(std::vector<float>({0.1, 0.2, -0.1, -0.2, -0.3}), datapoint_recon);
- }
- @@ -122,7 +123,8 @@ TEST(IndexerTest, SquaredL2AsymmetricReconstruct2) {
- indexer.EncodeDatapoint(datapoint, absl::MakeSpan(result));
-
- vector<float> datapoint_recon = {0.1, 0.2, -0.1, -0.2, -0.3};
- - SUPPORT_EXPECT_OK(indexer.DecodeDatapoint(result, absl::MakeSpan(datapoint_recon)));
- + SUPPORT_EXPECT_OK(
- + indexer.DecodeDatapoint(result, absl::MakeSpan(datapoint_recon)));
-
- EXPECT_EQ(std::vector<float>({0.9, 0.8, -0.3, -0.2, -0.1}), datapoint_recon);
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.cc
- index 3217c57c0e831..e86fd77cc3321 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.cc
- @@ -87,7 +87,9 @@ bool Partitioner::Partition(const Eigen::Ref<const Eigen::MatrixXf>& queries,
- return true;
- }
-
- -int Partitioner::NumPartitions() const { return leaves_.rows(); }
- +int Partitioner::NumPartitions() const {
- + return leaves_.rows();
- +}
-
- bool NoOpPartitioner::Partition(
- const Eigen::Ref<const Eigen::MatrixXf>& queries,
- @@ -108,7 +110,9 @@ bool NoOpPartitioner::Partition(
- return true;
- }
-
- -int NoOpPartitioner::NumPartitions() const { return 1; }
- +int NoOpPartitioner::NumPartitions() const {
- + return 1;
- +}
-
- } // namespace core
- } // namespace scann_ondevice
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h
- index 2a1fb36e9f28e..f4e9eb9e34804 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h
- @@ -17,8 +17,8 @@ limitations under the License.
-
- #include <utility>
-
- +#include "Eigen/Core" // from @eigen
- #include "absl/types/optional.h" // from @com_google_absl
- -#include "Eigen/Core" // from @eigen
- #include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
-
- namespace tflite {
- @@ -45,7 +45,8 @@ class Partitioner : public PartitionerInterface {
- }
-
- private:
- - Partitioner(Eigen::MatrixXf leaves, Eigen::VectorXf leaf_norms,
- + Partitioner(Eigen::MatrixXf leaves,
- + Eigen::VectorXf leaf_norms,
- DistanceMeasure distance)
- : leaves_(std::move(leaves)),
- leaf_norms_(std::move(leaf_norms)),
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher.h
- index 9fab870790db6..419681b829b1d 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher.h
- @@ -22,8 +22,8 @@ limitations under the License.
- #include <vector>
-
- #include <glog/logging.h>
- +#include "Eigen/Core" // from @eigen
- #include "absl/types/span.h" // from @com_google_absl
- -#include "Eigen/Core" // from @eigen
- #include "tensorflow_lite_support/cc/port/integral_types.h"
- #include "tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h"
- #include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h"
- @@ -47,7 +47,8 @@ void ComputeAHDistance(const QueryInfo& query_info,
- template <class T>
- bool AsymmetricHashFindNeighbors(const QueryInfo& query_info,
- Eigen::Ref<const Matrix8u> database,
- - size_t global_offset, absl::Span<T> topn) {
- + size_t global_offset,
- + absl::Span<T> topn) {
- const int batch_size = query_info.query_lut->cols();
- if (topn.size() != batch_size) {
- return false;
- @@ -67,7 +68,8 @@ template <class T>
- bool AsymmetricHashFindNeighbors(Eigen::Ref<const Eigen::MatrixXf> queries,
- const PreProcessorInterface& preprocessor,
- Eigen::Ref<const Matrix8u> database,
- - size_t global_offset, absl::Span<T> topn) {
- + size_t global_offset,
- + absl::Span<T> topn) {
- if (queries.cols() != topn.size()) {
- return false;
- }
- @@ -116,10 +118,12 @@ template <class T>
- class AsymmetricHashLeafSearcherT : public SearcherInterfaceT<T> {
- public:
- static std::unique_ptr<AsymmetricHashLeafSearcherT<T>> Create(
- - std::shared_ptr<QueryInfo::Matrix<uint8_t>> database, int global_offset,
- + std::shared_ptr<QueryInfo::Matrix<uint8_t>> database,
- + int global_offset,
- std::shared_ptr<PreProcessorInterface> preprocessor);
- static std::unique_ptr<AsymmetricHashLeafSearcherT<T>> Create(
- - std::shared_ptr<QueryInfo::Matrix<uint8_t>> database, int global_offset,
- + std::shared_ptr<QueryInfo::Matrix<uint8_t>> database,
- + int global_offset,
- std::shared_ptr<PreProcessorInterface> preprocessor,
- size_t mini_batch_size);
- bool FindNeighbors(const Eigen::Ref<const Eigen::MatrixXf>& queries,
- @@ -128,7 +132,8 @@ class AsymmetricHashLeafSearcherT : public SearcherInterfaceT<T> {
-
- private:
- AsymmetricHashLeafSearcherT(
- - std::shared_ptr<QueryInfo::Matrix<uint8_t>> database, int global_offset,
- + std::shared_ptr<QueryInfo::Matrix<uint8_t>> database,
- + int global_offset,
- std::shared_ptr<PreProcessorInterface> preprocessor,
- size_t mini_batch_size)
- : database_(std::move(database)),
- @@ -154,7 +159,8 @@ class LinearLeafSearcherT : public SearcherInterfaceT<T> {
-
- private:
- LinearLeafSearcherT(std::shared_ptr<Eigen::MatrixXf> database,
- - DistanceMeasure distance_measure, int global_offset)
- + DistanceMeasure distance_measure,
- + int global_offset)
- : database_(std::move(database)),
- distance_measure_(distance_measure),
- global_offset_(global_offset) {}
- @@ -167,7 +173,8 @@ class LinearLeafSearcherT : public SearcherInterfaceT<T> {
- template <class T>
- std::unique_ptr<AsymmetricHashLeafSearcherT<T>>
- AsymmetricHashLeafSearcherT<T>::Create(
- - std::shared_ptr<Matrix8u> database, int global_offset,
- + std::shared_ptr<Matrix8u> database,
- + int global_offset,
- std::shared_ptr<PreProcessorInterface> preprocessor) {
- return AsymmetricHashLeafSearcherT<T>::Create(
- database, global_offset, preprocessor,
- @@ -177,7 +184,8 @@ AsymmetricHashLeafSearcherT<T>::Create(
- template <class T>
- std::unique_ptr<AsymmetricHashLeafSearcherT<T>>
- AsymmetricHashLeafSearcherT<T>::Create(
- - std::shared_ptr<Matrix8u> database, int global_offset,
- + std::shared_ptr<Matrix8u> database,
- + int global_offset,
- std::shared_ptr<PreProcessorInterface> preprocessor,
- size_t mini_batch_size) {
- if (mini_batch_size == 0 || global_offset < 0) {
- @@ -220,7 +228,8 @@ bool AsymmetricHashLeafSearcherT<T>::FindNeighbors(const QueryInfo& query_info,
-
- template <class T>
- std::unique_ptr<LinearLeafSearcherT<T>> LinearLeafSearcherT<T>::Create(
- - std::shared_ptr<Eigen::MatrixXf> database, DistanceMeasure distance_measure,
- + std::shared_ptr<Eigen::MatrixXf> database,
- + DistanceMeasure distance_measure,
- int global_offset) {
- if (global_offset < 0) {
- return nullptr;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher_test.cc
- index 8c67bca0da939..f3931f3619b8d 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher_test.cc
- @@ -21,8 +21,8 @@ limitations under the License.
- #include <utility>
-
- #include <glog/logging.h>
- +#include "Eigen/Core" // from @eigen
- #include "absl/synchronization/mutex.h" // from @com_google_absl
- -#include "Eigen/Core" // from @eigen
- #include "tensorflow_lite_support/cc/port/gmock.h"
- #include "tensorflow_lite_support/cc/port/gtest.h"
- #include "tensorflow_lite_support/cc/port/integral_types.h"
- @@ -520,9 +520,10 @@ TEST_P(SearcherTest, AsymmetricHashMiniBatchedSimdFail) {
- }
- #endif
-
- -INSTANTIATE_TEST_SUITE_P(SearcherTest, SearcherTest,
- - Values(std::numeric_limits<size_t>::max(), 1, 2, 3, 7,
- - 23));
- +INSTANTIATE_TEST_SUITE_P(
- + SearcherTest,
- + SearcherTest,
- + Values(std::numeric_limits<size_t>::max(), 1, 2, 3, 7, 23));
-
- } // namespace
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h
- index 8f53ddf0669c4..3e5a6b00736d0 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h
- @@ -44,7 +44,8 @@ class TopNAmortizedConstant {
- std::vector<T> TakeUnsorted() {
- DCHECK_GT(limit_, 0) << "Cannot call TakeUnsorted on uninitialized "
- "TopNAmortizedConstant instance.";
- - if (elements_.size() > limit_) PartitionAndResizeToLimit();
- + if (elements_.size() > limit_)
- + PartitionAndResizeToLimit();
- auto result = std::move(elements_);
- elements_.clear();
- approx_bottom_ = original_approx_bottom_;
- @@ -53,13 +54,15 @@ class TopNAmortizedConstant {
- const std::vector<T>& ExtractUnsorted() {
- DCHECK_GT(limit_, 0) << "Cannot call ExtractUnsorted on uninitialized "
- "TopNAmortizedConstant instance.";
- - if (elements_.size() > limit_) PartitionAndResizeToLimit();
- + if (elements_.size() > limit_)
- + PartitionAndResizeToLimit();
- return elements_;
- }
- std::vector<T> Take() {
- DCHECK_GT(limit_, 0) << "Cannot call Take on uninitialized "
- "TopNAmortizedConstant instance.";
- - if (elements_.size() > limit_) PartitionAndResizeToLimit();
- + if (elements_.size() > limit_)
- + PartitionAndResizeToLimit();
- std::sort(elements_.begin(), elements_.end(), cmp_);
- auto result = std::move(elements_);
- elements_.clear();
- @@ -100,7 +103,8 @@ struct Comparator {
- const std::pair<float, int>& b) const {
- return a.first < b.first;
- }
- - bool operator()(float distance, int,
- + bool operator()(float distance,
- + int,
- const std::pair<float, int>& other) const {
- return distance < other.first;
- }
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.cc
- index 8e45119d7364d..e8be5f6572f17 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.cc
- @@ -18,17 +18,17 @@ limitations under the License.
- #include <cstddef>
- #include <memory>
-
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/status/statusor.h" // from @com_google_absl
- -#include "absl/strings/str_format.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/statusor.h" // from @com_google_absl
- +#include "absl/strings/str_format.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- -#include "leveldb/cache.h" // from @com_google_leveldb
- -#include "leveldb/iterator.h" // from @com_google_leveldb
- -#include "leveldb/options.h" // from @com_google_leveldb
- -#include "leveldb/slice.h" // from @com_google_leveldb
- -#include "leveldb/status.h" // from @com_google_leveldb
- -#include "leveldb/table.h" // from @com_google_leveldb
- +#include "leveldb/cache.h" // from @com_google_leveldb
- +#include "leveldb/iterator.h" // from @com_google_leveldb
- +#include "leveldb/options.h" // from @com_google_leveldb
- +#include "leveldb/slice.h" // from @com_google_leveldb
- +#include "leveldb/status.h" // from @com_google_leveldb
- +#include "leveldb/table.h" // from @com_google_leveldb
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- #include "tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h"
- #include "tensorflow_lite_support/scann_ondevice/cc/utils.h"
- @@ -60,7 +60,8 @@ absl::StatusOr<absl::string_view> GetValueForKey(leveldb::Iterator* iterator,
-
- /* static */
- absl::StatusOr<std::unique_ptr<Index>> Index::CreateFromIndexBuffer(
- - const char* buffer_data, size_t buffer_size) {
- + const char* buffer_data,
- + size_t buffer_size) {
- // Use absl::WrapUnique() to call private constructor:
- // https://abseil.io/tips/126.
- std::unique_ptr<Index> index = absl::WrapUnique(new Index());
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.h
- index c630e6f827caa..15e709183a606 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.h
- @@ -18,12 +18,12 @@ limitations under the License.
-
- #include <memory>
-
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/status/statusor.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/statusor.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- -#include "leveldb/cache.h" // from @com_google_leveldb
- -#include "leveldb/iterator.h" // from @com_google_leveldb
- -#include "leveldb/table.h" // from @com_google_leveldb
- +#include "leveldb/cache.h" // from @com_google_leveldb
- +#include "leveldb/iterator.h" // from @com_google_leveldb
- +#include "leveldb/table.h" // from @com_google_leveldb
- #include "tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h"
- #include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h"
-
- @@ -43,7 +43,8 @@ class Index {
- // Warning: Does not take ownership of the provided buffer, which must outlive
- // this object.
- static absl::StatusOr<std::unique_ptr<Index>> CreateFromIndexBuffer(
- - const char* buffer_data, size_t buffer_size);
- + const char* buffer_data,
- + size_t buffer_size);
-
- // Parses and returns the `IndexConfig` stored in the index file.
- absl::StatusOr<IndexConfig> GetIndexConfig() const;
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc
- index fe5d1ef1175e4..0d802024c2b01 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc
- @@ -21,13 +21,13 @@ limitations under the License.
- #include <vector>
-
- #include "absl/container/btree_map.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/str_format.h" // from @com_google_absl
- -#include "leveldb/options.h" // from @com_google_leveldb
- -#include "leveldb/slice.h" // from @com_google_leveldb
- -#include "leveldb/status.h" // from @com_google_leveldb
- -#include "leveldb/table_builder.h" // from @com_google_leveldb
- -#include "leveldb/write_batch.h" // from @com_google_leveldb
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/str_format.h" // from @com_google_absl
- +#include "leveldb/options.h" // from @com_google_leveldb
- +#include "leveldb/slice.h" // from @com_google_leveldb
- +#include "leveldb/status.h" // from @com_google_leveldb
- +#include "leveldb/table_builder.h" // from @com_google_leveldb
- +#include "leveldb/write_batch.h" // from @com_google_leveldb
- #include "tensorflow_lite_support/cc/port/status_macros.h"
- #include "tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h"
- #include "tensorflow_lite_support/scann_ondevice/cc/utils.h"
- @@ -56,8 +56,10 @@ template <typename T>
- absl::StatusOr<std::string> CreateIndexBufferImpl(
- absl::Span<const T> database,
- absl::optional<absl::Span<const uint32_t>> partition_assignment,
- - absl::Span<const std::string> metadata, const std::string& userinfo,
- - IndexConfig index_config, bool compression) {
- + absl::Span<const std::string> metadata,
- + const std::string& userinfo,
- + IndexConfig index_config,
- + bool compression) {
- size_t num_partitions = 1;
- if (partition_assignment) {
- if (partition_assignment->size() != metadata.size()) {
- @@ -145,8 +147,8 @@ absl::StatusOr<std::string> CreateIndexBufferImpl(
-
- } // namespace
-
- -absl::StatusOr<std::string> CreateIndexBuffer(
- - const IndexedArtifacts& artifacts, bool compression) {
- +absl::StatusOr<std::string> CreateIndexBuffer(const IndexedArtifacts& artifacts,
- + bool compression) {
- if (artifacts.hashed_database.has_value() &&
- artifacts.float_database.has_value()) {
- return absl::InvalidArgumentError(
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h
- index e8f8f06220578..53cac9b583da4 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h
- @@ -16,12 +16,12 @@ limitations under the License.
- #ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_INDEX_FILE_MUTATOR_H_
- #define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_INDEX_FILE_MUTATOR_H_
-
- -#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
- -#include "absl/status/statusor.h" // from @com_google_absl
- +#include "absl/status/statusor.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- -#include "absl/types/optional.h" // from @com_google_absl
- -#include "absl/types/span.h" // from @com_google_absl
- -#include "leveldb/db.h" // from @com_google_leveldb
- +#include "absl/types/optional.h" // from @com_google_absl
- +#include "absl/types/span.h" // from @com_google_absl
- +#include "leveldb/db.h" // from @com_google_leveldb
- +#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
- #include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h"
-
- namespace tflite {
- @@ -60,8 +60,8 @@ struct IndexedArtifacts {
- // Creates a byte buffer for the index file from the artifacts. Returns errors
- // when there are not exactly one database specified, or other issues with input
- // such as shape mismatch, invalid partition indices etc.
- -absl::StatusOr<std::string> CreateIndexBuffer(
- - const IndexedArtifacts& artifacts, bool compression);
- +absl::StatusOr<std::string> CreateIndexBuffer(const IndexedArtifacts& artifacts,
- + bool compression);
-
- } // namespace scann_ondevice
- } // namespace tflite
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.cc
- index 7be71b90ef91d..59b9deb8e8682 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.cc
- @@ -19,8 +19,8 @@ limitations under the License.
- #include <cstddef>
- #include <cstdint>
-
- -#include "leveldb/env.h" // from @com_google_leveldb
- -#include "leveldb/slice.h" // from @com_google_leveldb
- +#include "leveldb/env.h" // from @com_google_leveldb
- +#include "leveldb/slice.h" // from @com_google_leveldb
- #include "leveldb/status.h" // from @com_google_leveldb
-
- namespace tflite {
- @@ -32,7 +32,8 @@ MemRandomAccessFile::MemRandomAccessFile(const char* buffer_data,
-
- MemRandomAccessFile::~MemRandomAccessFile() {}
-
- -leveldb::Status MemRandomAccessFile::Read(uint64_t offset, size_t n,
- +leveldb::Status MemRandomAccessFile::Read(uint64_t offset,
- + size_t n,
- leveldb::Slice* result,
- char* scratch) const {
- // Sanity check.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h
- index 0cf9cbfed59f4..5ca68f2e2c91e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h
- @@ -19,8 +19,8 @@ limitations under the License.
- #include <cstddef>
- #include <cstdint>
-
- -#include "leveldb/env.h" // from @com_google_leveldb
- -#include "leveldb/slice.h" // from @com_google_leveldb
- +#include "leveldb/env.h" // from @com_google_leveldb
- +#include "leveldb/slice.h" // from @com_google_leveldb
- #include "leveldb/status.h" // from @com_google_leveldb
-
- namespace tflite {
- @@ -39,7 +39,9 @@ class MemRandomAccessFile : public leveldb::RandomAccessFile {
-
- // Override of the `Read` function. Note that `scratch` is unused in the
- // implementation.
- - leveldb::Status Read(uint64_t offset, size_t n, leveldb::Slice* result,
- + leveldb::Status Read(uint64_t offset,
- + size_t n,
- + leveldb::Slice* result,
- char* scratch) const override;
-
- // Class is movable and non-copyable.
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h
- index bb346bc7f12dc..842e837927d4e 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h
- @@ -20,10 +20,10 @@ limitations under the License.
- #include <string>
-
- #include "absl/status/statusor.h" // from @com_google_absl
- -#include "absl/strings/cord.h" // from @com_google_absl
- -#include "leveldb/env.h" // from @com_google_leveldb
- -#include "leveldb/slice.h" // from @com_google_leveldb
- -#include "leveldb/status.h" // from @com_google_leveldb
- +#include "absl/strings/cord.h" // from @com_google_absl
- +#include "leveldb/env.h" // from @com_google_leveldb
- +#include "leveldb/slice.h" // from @com_google_leveldb
- +#include "leveldb/status.h" // from @com_google_leveldb
-
- namespace tflite {
- namespace scann_ondevice {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/python/index_builder_py_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/python/index_builder_py_wrapper.cc
- index da147af88bc2a..709564035ff1f 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/python/index_builder_py_wrapper.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/python/index_builder_py_wrapper.cc
- @@ -15,14 +15,14 @@ limitations under the License.
-
- #include <string>
-
- -#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
- #include "absl/types/optional.h" // from @com_google_absl
- -#include "absl/types/span.h" // from @com_google_absl
- +#include "absl/types/span.h" // from @com_google_absl
- #include "pybind11/cast.h"
- #include "pybind11/pybind11.h"
- #include "pybind11/pytypes.h"
- -#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil
- +#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil
- #include "pybind11_abseil/status_casters.h" // from @pybind11_abseil
- +#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
- #include "tensorflow_lite_support/scann_ondevice/cc/index_builder.h"
-
- namespace pybind11 {
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc
- index 07da739f4a888..a1af840cc2f14 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc
- @@ -18,18 +18,18 @@ limitations under the License.
- #include <cstdint>
- #include <string>
-
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/strings/str_format.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/strings/str_format.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- -#include "absl/types/span.h" // from @com_google_absl
- -#include "leveldb/env.h" // from @com_google_leveldb
- -#include "leveldb/iterator.h" // from @com_google_leveldb
- -#include "leveldb/options.h" // from @com_google_leveldb
- -#include "leveldb/slice.h" // from @com_google_leveldb
- -#include "leveldb/status.h" // from @com_google_leveldb
- -#include "leveldb/table.h" // from @com_google_leveldb
- +#include "absl/types/span.h" // from @com_google_absl
- +#include "leveldb/env.h" // from @com_google_leveldb
- +#include "leveldb/iterator.h" // from @com_google_leveldb
- +#include "leveldb/options.h" // from @com_google_leveldb
- +#include "leveldb/slice.h" // from @com_google_leveldb
- +#include "leveldb/status.h" // from @com_google_leveldb
- +#include "leveldb/table.h" // from @com_google_leveldb
- #include "tensorflow_lite_support/cc/port/gmock.h"
- #include "tensorflow_lite_support/cc/port/gtest.h"
- #include "tensorflow_lite_support/cc/port/status_matchers.h"
- @@ -137,22 +137,23 @@ TEST_P(PopulateIndexFileTest, WritesHashedDatabaseWithPartitioner) {
-
- {
- tflite::scann_ondevice::core::ScannOnDeviceConfig config =
- - ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(R"pb(
- - partitioner: {
- - leaf { dimension: 0 dimension: 0 }
- - leaf { dimension: 1 dimension: 1 }
- - leaf { dimension: 2 dimension: 2 }
- - leaf { dimension: 3 dimension: 3 }
- - leaf { dimension: 4 dimension: 4 }
- - leaf { dimension: 5 dimension: 5 }
- - leaf { dimension: 6 dimension: 6 }
- - leaf { dimension: 7 dimension: 7 }
- - leaf { dimension: 8 dimension: 8 }
- - leaf { dimension: 9 dimension: 9 }
- - leaf { dimension: 10 dimension: 10 }
- - leaf { dimension: 11 dimension: 11 }
- - }
- - )pb");
- + ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(
- + R"pb(
- + partitioner: {
- + leaf { dimension: 0 dimension: 0 }
- + leaf { dimension: 1 dimension: 1 }
- + leaf { dimension: 2 dimension: 2 }
- + leaf { dimension: 3 dimension: 3 }
- + leaf { dimension: 4 dimension: 4 }
- + leaf { dimension: 5 dimension: 5 }
- + leaf { dimension: 6 dimension: 6 }
- + leaf { dimension: 7 dimension: 7 }
- + leaf { dimension: 8 dimension: 8 }
- + leaf { dimension: 9 dimension: 9 }
- + leaf { dimension: 10 dimension: 10 }
- + leaf { dimension: 11 dimension: 11 }
- + }
- + )pb");
- std::vector<uint8_t> hashed_database;
- hashed_database.reserve(kNumEmbeddings * kDimensions);
- for (int i = 0; i < kNumEmbeddings; ++i) {
- @@ -202,16 +203,18 @@ TEST_P(PopulateIndexFileTest, WritesHashedDatabaseWithPartitioner) {
- auto hashed_table_iterator =
- absl::WrapUnique(hashed_table->NewIterator(leveldb::ReadOptions()));
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string serialized_config,
- - LookupKey(hashed_table_iterator.get(), "INDEX_CONFIG"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::string serialized_config,
- + LookupKey(hashed_table_iterator.get(), "INDEX_CONFIG"));
- IndexConfig index_config;
- EXPECT_TRUE(index_config.ParseFromString(serialized_config));
- EXPECT_THAT(
- index_config,
- EqualsProto(CreateExpectedConfigWithPartitioner(IndexConfig::UINT8)));
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string userinfo,
- - LookupKey(hashed_table_iterator.get(), "USER_INFO"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::string userinfo,
- + LookupKey(hashed_table_iterator.get(), "USER_INFO"));
- EXPECT_EQ(userinfo, "hashed_userinfo");
-
- // Partition assignment is based on i % kNumPartitions, so:
- @@ -253,9 +256,10 @@ TEST_P(PopulateIndexFileTest, WritesHashedDatabaseWithoutPartitioner) {
-
- {
- tflite::scann_ondevice::core::ScannOnDeviceConfig config =
- - ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(R"pb(
- - query_distance: SQUARED_L2_DISTANCE
- - )pb");
- + ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(
- + R"pb(
- + query_distance: SQUARED_L2_DISTANCE
- + )pb");
- std::vector<uint8_t> hashed_database;
- hashed_database.reserve(kNumEmbeddings * kDimensions);
- for (int i = 0; i < kNumEmbeddings; ++i) {
- @@ -299,22 +303,23 @@ TEST_P(PopulateIndexFileTest, WritesHashedDatabaseWithoutPartitioner) {
- auto float_table_iterator =
- absl::WrapUnique(float_table->NewIterator(leveldb::ReadOptions()));
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string serialized_config,
- - LookupKey(float_table_iterator.get(), "INDEX_CONFIG"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::string serialized_config,
- + LookupKey(float_table_iterator.get(), "INDEX_CONFIG"));
- IndexConfig index_config;
- EXPECT_TRUE(index_config.ParseFromString(serialized_config));
- EXPECT_THAT(
- index_config,
- EqualsProto(CreateExpectedConfigWithoutPartitioner(IndexConfig::UINT8)));
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string userinfo,
- - LookupKey(float_table_iterator.get(), "USER_INFO"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::string userinfo, LookupKey(float_table_iterator.get(), "USER_INFO"));
- EXPECT_EQ(userinfo, "hashed_userinfo");
-
- // Check that the unique embedding partition has the exact same contents as
- // the database used at construction time.
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::string raw_partition_hashed,
- - LookupKey(float_table_iterator.get(), "E_0"));
- + LookupKey(float_table_iterator.get(), "E_0"));
- std::vector<char> hashed_partition(raw_partition_hashed.begin(),
- raw_partition_hashed.end());
- std::vector<char> expected;
- @@ -342,22 +347,23 @@ TEST_P(PopulateIndexFileTest, WritesFloatDatabaseWithPartitioner) {
-
- {
- tflite::scann_ondevice::core::ScannOnDeviceConfig config =
- - ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(R"pb(
- - partitioner: {
- - leaf { dimension: 0 dimension: 0 }
- - leaf { dimension: 1 dimension: 1 }
- - leaf { dimension: 2 dimension: 2 }
- - leaf { dimension: 3 dimension: 3 }
- - leaf { dimension: 4 dimension: 4 }
- - leaf { dimension: 5 dimension: 5 }
- - leaf { dimension: 6 dimension: 6 }
- - leaf { dimension: 7 dimension: 7 }
- - leaf { dimension: 8 dimension: 8 }
- - leaf { dimension: 9 dimension: 9 }
- - leaf { dimension: 10 dimension: 10 }
- - leaf { dimension: 11 dimension: 11 }
- - }
- - )pb");
- + ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(
- + R"pb(
- + partitioner: {
- + leaf { dimension: 0 dimension: 0 }
- + leaf { dimension: 1 dimension: 1 }
- + leaf { dimension: 2 dimension: 2 }
- + leaf { dimension: 3 dimension: 3 }
- + leaf { dimension: 4 dimension: 4 }
- + leaf { dimension: 5 dimension: 5 }
- + leaf { dimension: 6 dimension: 6 }
- + leaf { dimension: 7 dimension: 7 }
- + leaf { dimension: 8 dimension: 8 }
- + leaf { dimension: 9 dimension: 9 }
- + leaf { dimension: 10 dimension: 10 }
- + leaf { dimension: 11 dimension: 11 }
- + }
- + )pb");
- std::vector<float> float_database;
- float_database.reserve(kNumEmbeddings * kDimensions);
- for (int i = 0; i < kNumEmbeddings; ++i) {
- @@ -407,16 +413,17 @@ TEST_P(PopulateIndexFileTest, WritesFloatDatabaseWithPartitioner) {
- auto float_table_iterator =
- absl::WrapUnique(float_table->NewIterator(leveldb::ReadOptions()));
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string serialized_config,
- - LookupKey(float_table_iterator.get(), "INDEX_CONFIG"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::string serialized_config,
- + LookupKey(float_table_iterator.get(), "INDEX_CONFIG"));
- IndexConfig index_config;
- EXPECT_TRUE(index_config.ParseFromString(serialized_config));
- EXPECT_THAT(
- index_config,
- EqualsProto(CreateExpectedConfigWithPartitioner(IndexConfig::FLOAT)));
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string userinfo,
- - LookupKey(float_table_iterator.get(), "USER_INFO"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::string userinfo, LookupKey(float_table_iterator.get(), "USER_INFO"));
- EXPECT_EQ(userinfo, "float_userinfo");
-
- // Partition assignment is based on i % kNumPartitions, so:
- @@ -461,9 +468,10 @@ TEST_P(PopulateIndexFileTest, WritesFloatDatabaseWithoutPartitioner) {
-
- {
- tflite::scann_ondevice::core::ScannOnDeviceConfig config =
- - ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(R"pb(
- - query_distance: SQUARED_L2_DISTANCE
- - )pb");
- + ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(
- + R"pb(
- + query_distance: SQUARED_L2_DISTANCE
- + )pb");
- std::vector<float> float_database;
- float_database.reserve(kNumEmbeddings * kDimensions);
- for (int i = 0; i < kNumEmbeddings; ++i) {
- @@ -506,22 +514,23 @@ TEST_P(PopulateIndexFileTest, WritesFloatDatabaseWithoutPartitioner) {
- auto float_table_iterator =
- absl::WrapUnique(float_table->NewIterator(leveldb::ReadOptions()));
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string serialized_config,
- - LookupKey(float_table_iterator.get(), "INDEX_CONFIG"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::string serialized_config,
- + LookupKey(float_table_iterator.get(), "INDEX_CONFIG"));
- IndexConfig index_config;
- EXPECT_TRUE(index_config.ParseFromString(serialized_config));
- EXPECT_THAT(
- index_config,
- EqualsProto(CreateExpectedConfigWithoutPartitioner(IndexConfig::FLOAT)));
-
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string userinfo,
- - LookupKey(float_table_iterator.get(), "USER_INFO"));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::string userinfo, LookupKey(float_table_iterator.get(), "USER_INFO"));
- EXPECT_EQ(userinfo, "float_userinfo");
-
- // Check that the unique embedding partition has the exact same contents as
- // the database used at construction time.
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::string raw_partition_float,
- - LookupKey(float_table_iterator.get(), "E_0"));
- + LookupKey(float_table_iterator.get(), "E_0"));
- const float* raw_partition_float_ptr =
- reinterpret_cast<const float*>(raw_partition_float.data());
- std::vector<float> float_partition(
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_test.cc
- index 983dd8d2bc8e8..cc1225f679f66 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_test.cc
- @@ -18,9 +18,8 @@ limitations under the License.
- #include <cstdint>
- #include <memory>
-
- -#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
- -#include "absl/flags/flag.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/flags/flag.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- #include "absl/strings/string_view.h" // from @com_google_absl
- #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
- #include "tensorflow_lite_support/cc/port/gmock.h"
- @@ -29,6 +28,7 @@ limitations under the License.
- #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
- #include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
- #include "tensorflow_lite_support/cc/test/test_utils.h"
- +#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
- #include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h"
-
- namespace tflite {
- @@ -47,10 +47,10 @@ constexpr char kDummyIndexPath[] =
- TEST(CreateFromOptionsTest, Succeeds) {
- // Load file in memory using ExternalFile.
- ExternalFile file;
- - file.set_file_name(
- - JoinPath("./" /*test src dir*/, kDummyIndexPath));
- - SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ExternalFileHandler> handler,
- - ExternalFileHandler::CreateFromExternalFile(&file));
- + file.set_file_name(JoinPath("./" /*test src dir*/, kDummyIndexPath));
- + SUPPORT_ASSERT_OK_AND_ASSIGN(
- + std::unique_ptr<ExternalFileHandler> handler,
- + ExternalFileHandler::CreateFromExternalFile(&file));
- absl::string_view file_contents = handler->GetFileContent();
-
- SUPPORT_EXPECT_OK(
- @@ -62,8 +62,7 @@ class IndexTest : public tflite_shims::testing::Test {
- IndexTest() {
- // Load file in memory using ExternalFile.
- ExternalFile file;
- - file.set_file_name(
- - JoinPath("./" /*test src dir*/, kDummyIndexPath));
- + file.set_file_name(JoinPath("./" /*test src dir*/, kDummyIndexPath));
- handler_ = ExternalFileHandler::CreateFromExternalFile(&file).value();
- absl::string_view file_contents = handler_->GetFileContent();
- // Build index.
- @@ -98,18 +97,18 @@ TEST_F(IndexTest, GetUserInfoSucceeds) {
-
- TEST_F(IndexTest, GetPartitionAtIndexSucceeds) {
- SUPPORT_ASSERT_OK_AND_ASSIGN(absl::string_view partition_0,
- - index_->GetPartitionAtIndex(0));
- + index_->GetPartitionAtIndex(0));
- EXPECT_EQ(partition_0.size(), 8);
- - const uint8_t *partition =
- - reinterpret_cast<const uint8_t *>(partition_0.data());
- + const uint8_t* partition =
- + reinterpret_cast<const uint8_t*>(partition_0.data());
- for (int i = 0; i < 8; ++i) {
- EXPECT_EQ(partition[i], i);
- }
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(absl::string_view partition_1,
- - index_->GetPartitionAtIndex(1));
- + index_->GetPartitionAtIndex(1));
- EXPECT_EQ(partition_1.size(), 4);
- - partition = reinterpret_cast<const uint8_t *>(partition_1.data());
- + partition = reinterpret_cast<const uint8_t*>(partition_1.data());
- for (int i = 0; i < 4; ++i) {
- EXPECT_EQ(partition[i], i + 8);
- }
- @@ -122,15 +121,15 @@ TEST_F(IndexTest, GetPartitionAtIndexFailsOutOfBounds) {
-
- TEST_F(IndexTest, GetMetadataAtIndexSucceeds) {
- SUPPORT_ASSERT_OK_AND_ASSIGN(absl::string_view metadata_0,
- - index_->GetMetadataAtIndex(0));
- + index_->GetMetadataAtIndex(0));
- EXPECT_EQ(metadata_0, "metadata_0");
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(absl::string_view metadata_1,
- - index_->GetMetadataAtIndex(1));
- + index_->GetMetadataAtIndex(1));
- EXPECT_EQ(metadata_1, "metadata_1");
-
- SUPPORT_ASSERT_OK_AND_ASSIGN(absl::string_view metadata_2,
- - index_->GetMetadataAtIndex(2));
- + index_->GetMetadataAtIndex(2));
- EXPECT_EQ(metadata_2, "metadata_2");
- }
-
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/mem_writable_file_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/mem_writable_file_test.cc
- index 3230b34db05ba..afb55e5472161 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/mem_writable_file_test.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/mem_writable_file_test.cc
- @@ -26,7 +26,7 @@ namespace {
- TEST(MemWritableFileTest, AppendsContent) {
- std::string buffer;
- SUPPORT_ASSERT_OK_AND_ASSIGN(auto mem_writable_file,
- - MemWritableFile::Create(&buffer));
- + MemWritableFile::Create(&buffer));
-
- ASSERT_TRUE(mem_writable_file->Append("aaa").ok());
- EXPECT_EQ(buffer, "aaa");
- diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc
- index ca364e06e7d1d..1ae7e0ce9ed09 100644
- --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc
- +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc
- @@ -16,17 +16,17 @@ limitations under the License.
- #include <cstdint>
- #include <vector>
-
- -#include "absl/memory/memory.h" // from @com_google_absl
- -#include "absl/status/status.h" // from @com_google_absl
- -#include "absl/status/statusor.h" // from @com_google_absl
- +#include "absl/memory/memory.h" // from @com_google_absl
- +#include "absl/status/status.h" // from @com_google_absl
- +#include "absl/status/statusor.h" // from @com_google_absl
- #include "absl/strings/str_format.h" // from @com_google_absl
- -#include "leveldb/env.h" // from @com_google_leveldb
- -#include "leveldb/options.h" // from @com_google_leveldb
- -#include "leveldb/table.h" // from @com_google_leveldb
- +#include "leveldb/env.h" // from @com_google_leveldb
- +#include "leveldb/options.h" // from @com_google_leveldb
- +#include "leveldb/table.h" // from @com_google_leveldb
- #include "pybind11/cast.h"
- #include "pybind11/pybind11.h"
- #include "pybind11/pytypes.h"
- -#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil
- +#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil
- #include "pybind11_abseil/status_casters.h" // from @pybind11_abseil
-
- namespace pybind11 {
- diff --git a/third_party/tflite_support/src/third_party/fft2d/fft.h b/third_party/tflite_support/src/third_party/fft2d/fft.h
- index 36d838b7f6280..35dbcc766c169 100644
- --- a/third_party/tflite_support/src/third_party/fft2d/fft.h
- +++ b/third_party/tflite_support/src/third_party/fft2d/fft.h
- @@ -22,12 +22,12 @@ limitations under the License.
- extern "C" {
- #endif
-
- -extern void cdft(int, int, double *, int *, double *);
- -extern void rdft(int, int, double *, int *, double *);
- -extern void ddct(int, int, double *, int *, double *);
- -extern void ddst(int, int, double *, int *, double *);
- -extern void dfct(int, double *, double *, int *, double *);
- -extern void dfst(int, double *, double *, int *, double *);
- +extern void cdft(int, int, double*, int*, double*);
- +extern void rdft(int, int, double*, int*, double*);
- +extern void ddct(int, int, double*, int*, double*);
- +extern void ddst(int, int, double*, int*, double*);
- +extern void dfct(int, double*, double*, int*, double*);
- +extern void dfst(int, double*, double*, int*, double*);
-
- #ifdef __cplusplus
- }
- diff --git a/third_party/tflite_support/src/third_party/fft2d/fft2d.h b/third_party/tflite_support/src/third_party/fft2d/fft2d.h
- index d587b3b441ce2..d79441827d54c 100644
- --- a/third_party/tflite_support/src/third_party/fft2d/fft2d.h
- +++ b/third_party/tflite_support/src/third_party/fft2d/fft2d.h
- @@ -22,12 +22,12 @@ limitations under the License.
- extern "C" {
- #endif
-
- -extern void cdft2d(int, int, int, double **, double *, int *, double *);
- -extern void rdft2d(int, int, int, double **, double *, int *, double *);
- -extern void ddct2d(int, int, int, double **, double *, int *, double *);
- -extern void ddst2d(int, int, int, double **, double *, int *, double *);
- -extern void ddct8x8s(int isgn, double **a);
- -extern void ddct16x16s(int isgn, double **a);
- +extern void cdft2d(int, int, int, double**, double*, int*, double*);
- +extern void rdft2d(int, int, int, double**, double*, int*, double*);
- +extern void ddct2d(int, int, int, double**, double*, int*, double*);
- +extern void ddst2d(int, int, int, double**, double*, int*, double*);
- +extern void ddct8x8s(int isgn, double** a);
- +extern void ddct16x16s(int isgn, double** a);
-
- #ifdef __cplusplus
- }
- --
- 2.36.1.124.g0e6072fb45-goog
|