# ====================================================================
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
# ====================================================================

from unittest import TestCase, main
from lucene import *


class PositionIncrementTestCase(TestCase):
    """
    Unit tests ported from Java Lucene (TestPositionIncrement).

    Exercises how position increments emitted by a TokenStream affect
    term positions in the index, PhraseQuery/MultiPhraseQuery matching,
    and payload retrieval via span queries.
    """

    def testSetPosition(self):
        """
        Index a single document through an analyzer that emits the tokens
        "1".."5" with position increments [1, 2, 1, 0, 1] (so "2" leaves a
        gap after "1", and "4" stacks on the same position as "3"), then
        verify term positions and phrase-query behavior against that layout.
        """

        class _analyzer(PythonAnalyzer):
            # Analyzer returning a fixed, hand-built token stream; the
            # document text itself ("bogus") is ignored.

            def tokenStream(_self, fieldName, reader):

                class _tokenStream(PythonTokenStream):

                    def __init__(self_):
                        super(_tokenStream, self_).__init__()
                        self_.TOKENS = ["1", "2", "3", "4", "5"]
                        self_.INCREMENTS = [1, 2, 1, 0, 1]
                        self_.i = 0
                        self_.posIncrAtt = self_.addAttribute(PositionIncrementAttribute.class_)
                        self_.termAtt = self_.addAttribute(TermAttribute.class_)
                        self_.offsetAtt = self_.addAttribute(OffsetAttribute.class_)

                    def incrementToken(self_):
                        if self_.i == len(self_.TOKENS):
                            return False

                        self_.termAtt.setTermBuffer(self_.TOKENS[self_.i])
                        self_.offsetAtt.setOffset(self_.i, self_.i)
                        self_.posIncrAtt.setPositionIncrement(self_.INCREMENTS[self_.i])
                        self_.i += 1

                        return True

                    def end(self_):
                        pass

                    def reset(self_):
                        pass

                    def close(self_):
                        pass

                return _tokenStream()

        analyzer = _analyzer()

        store = RAMDirectory()
        writer = IndexWriter(store, analyzer, True,
                             IndexWriter.MaxFieldLength.LIMITED)
        d = Document()
        d.add(Field("field", "bogus", Field.Store.YES, Field.Index.ANALYZED))
        writer.addDocument(d)
        writer.optimize()
        writer.close()

        searcher = IndexSearcher(store, True)

        pos = searcher.getIndexReader().termPositions(Term("field", "1"))
        pos.next()
        # first token should be at position 0
        self.assertEqual(0, pos.nextPosition())

        pos = searcher.getIndexReader().termPositions(Term("field", "2"))
        pos.next()
        # second token should be at position 2 (increment of 2 skipped pos 1)
        self.assertEqual(2, pos.nextPosition())

        q = PhraseQuery()
        q.add(Term("field", "1"))
        q.add(Term("field", "2"))
        hits = searcher.search(q, None, 1000).scoreDocs
        self.assertEqual(0, len(hits))

        # same as previous, just specify positions explicitely.
        q = PhraseQuery()
        q.add(Term("field", "1"), 0)
        q.add(Term("field", "2"), 1)
        hits = searcher.search(q, None, 1000).scoreDocs
        self.assertEqual(0, len(hits))

        # specifying correct positions should find the phrase.
        q = PhraseQuery()
        q.add(Term("field", "1"), 0)
        q.add(Term("field", "2"), 2)
        hits = searcher.search(q, None, 1000).scoreDocs
        self.assertEqual(1, len(hits))

        q = PhraseQuery()
        q.add(Term("field", "2"))
        q.add(Term("field", "3"))
        hits = searcher.search(q, None, 1000).scoreDocs
        self.assertEqual(1, len(hits))

        # "3" and "4" share a position, so the adjacent phrase "3 4" misses.
        q = PhraseQuery()
        q.add(Term("field", "3"))
        q.add(Term("field", "4"))
        hits = searcher.search(q, None, 1000).scoreDocs
        self.assertEqual(0, len(hits))

        # phrase query would find it when correct positions are specified.
        q = PhraseQuery()
        q.add(Term("field", "3"), 0)
        q.add(Term("field", "4"), 0)
        hits = searcher.search(q, None, 1000).scoreDocs
        self.assertEqual(1, len(hits))

        # phrase query should fail for non existing searched term
        # even if there exist another searched terms in the same searched
        # position.
        q = PhraseQuery()
        q.add(Term("field", "3"), 0)
        q.add(Term("field", "9"), 0)
        hits = searcher.search(q, None, 1000).scoreDocs
        self.assertEqual(0, len(hits))

        # multi-phrase query should succed for non existing searched term
        # because there exist another searched terms in the same searched
        # position.
        mq = MultiPhraseQuery()
        mq.add([Term("field", "3"), Term("field", "9")], 0)
        hits = searcher.search(mq, None, 1000).scoreDocs
        self.assertEqual(1, len(hits))

        q = PhraseQuery()
        q.add(Term("field", "2"))
        q.add(Term("field", "4"))
        hits = searcher.search(q, None, 1000).scoreDocs
        self.assertEqual(1, len(hits))

        q = PhraseQuery()
        q.add(Term("field", "3"))
        q.add(Term("field", "5"))
        hits = searcher.search(q, None, 1000).scoreDocs
        self.assertEqual(1, len(hits))

        q = PhraseQuery()
        q.add(Term("field", "4"))
        q.add(Term("field", "5"))
        hits = searcher.search(q, None, 1000).scoreDocs
        self.assertEqual(1, len(hits))

        q = PhraseQuery()
        q.add(Term("field", "2"))
        q.add(Term("field", "5"))
        hits = searcher.search(q, None, 1000).scoreDocs
        self.assertEqual(0, len(hits))

        # should not find "1 2" because there is a gap of 1 in the index
        qp = QueryParser(Version.LUCENE_CURRENT, "field",
                         StopWhitespaceAnalyzer(False))
        q = PhraseQuery.cast_(qp.parse("\"1 2\""))
        hits = searcher.search(q, None, 1000).scoreDocs
        self.assertEqual(0, len(hits))

        # omitted stop word cannot help because stop filter swallows the
        # increments.
        q = PhraseQuery.cast_(qp.parse("\"1 stop 2\""))
        hits = searcher.search(q, None, 1000).scoreDocs
        self.assertEqual(0, len(hits))

        # query parser alone won't help, because stop filter swallows the
        # increments.
        qp.setEnablePositionIncrements(True)
        q = PhraseQuery.cast_(qp.parse("\"1 stop 2\""))
        hits = searcher.search(q, None, 1000).scoreDocs
        self.assertEqual(0, len(hits))

        # stop filter alone won't help, because query parser swallows the
        # increments.
        qp.setEnablePositionIncrements(False)
        q = PhraseQuery.cast_(qp.parse("\"1 stop 2\""))
        hits = searcher.search(q, None, 1000).scoreDocs
        self.assertEqual(0, len(hits))

        # when both qp qnd stopFilter propagate increments, we should find
        # the doc.
        qp = QueryParser(Version.LUCENE_CURRENT, "field",
                         StopWhitespaceAnalyzer(True))
        qp.setEnablePositionIncrements(True)
        q = PhraseQuery.cast_(qp.parse("\"1 stop 2\""))
        hits = searcher.search(q, None, 1000).scoreDocs
        self.assertEqual(1, len(hits))

    def testPayloadsPos0(self):
        """
        Index one document through TestPayloadAnalyzer (which stacks every
        other token on the previous position and attaches a "pos: N" payload
        to each token), then check term positions, span-query matches and
        payload retrieval — in particular that position 0 is reachable.
        """
        dir = RAMDirectory()
        writer = IndexWriter(dir, TestPayloadAnalyzer(), True,
                             IndexWriter.MaxFieldLength.LIMITED)
        doc = Document()
        doc.add(Field("content",
                      StringReader("a a b c d e a f g h i j a b k k")))
        writer.addDocument(doc)

        r = writer.getReader()

        tp = r.termPositions(Term("content", "a"))
        count = 0
        self.assert_(tp.next())
        # "a" occurs 4 times
        self.assertEqual(4, tp.freq())

        expected = 0
        self.assertEqual(expected, tp.nextPosition())
        self.assertEqual(1, tp.nextPosition())
        self.assertEqual(3, tp.nextPosition())
        self.assertEqual(6, tp.nextPosition())

        # only one doc has "a"
        self.assert_(not tp.next())

        searcher = IndexSearcher(r)

        stq1 = SpanTermQuery(Term("content", "a"))
        stq2 = SpanTermQuery(Term("content", "k"))
        sqs = [stq1, stq2]
        snq = SpanNearQuery(sqs, 30, False)

        count = 0
        sawZero = False

        pspans = snq.getSpans(searcher.getIndexReader())
        while pspans.next():
            payloads = pspans.getPayload()
            sawZero |= pspans.start() == 0
            it = payloads.iterator()
            while it.hasNext():
                count += 1
                it.next()

        self.assertEqual(5, count)
        self.assert_(sawZero)

        spans = snq.getSpans(searcher.getIndexReader())
        count = 0
        sawZero = False
        while spans.next():
            count += 1
            sawZero |= spans.start() == 0

        self.assertEqual(4, count)
        self.assert_(sawZero)

        sawZero = False
        psu = PayloadSpanUtil(searcher.getIndexReader())
        pls = psu.getPayloadsForQuery(snq)
        count = pls.size()
        it = pls.iterator()
        while it.hasNext():
            # payload bytes hold the "pos: N" string written by PayloadFilter
            data = JArray('byte').cast_(it.next())
            s = data.string_
            sawZero |= s == "pos: 0"

        self.assertEqual(5, count)
        self.assert_(sawZero)

        writer.close()
        searcher.getIndexReader().close()
        dir.close()


class StopWhitespaceAnalyzer(PythonAnalyzer):
    """
    Whitespace analyzer followed by a StopFilter that removes the word
    "stop"; position-increment propagation through the filter is
    switchable via the constructor flag.
    """

    def __init__(self, enablePositionIncrements):
        super(StopWhitespaceAnalyzer, self).__init__()

        # whether the StopFilter preserves gaps left by removed stop words
        self.enablePositionIncrements = enablePositionIncrements
        self.a = WhitespaceAnalyzer()

    def tokenStream(self, fieldName, reader):
        ts = self.a.tokenStream(fieldName, reader)
        # renamed from 'set' to avoid shadowing the builtin
        stopWords = HashSet()
        stopWords.add("stop")

        return StopFilter(self.enablePositionIncrements, ts, stopWords)


class TestPayloadAnalyzer(PythonAnalyzer):
    """Lower-case tokenizer wrapped in a position/payload-stamping filter."""

    def tokenStream(self, fieldName, reader):
        result = LowerCaseTokenizer(reader)
        return PayloadFilter(result, fieldName)


class PayloadFilter(PythonTokenFilter):
    """
    TokenFilter that attaches a "pos: N" payload to every token and gives
    every other token a position increment of 0, stacking token pairs on
    the same position.
    """

    def __init__(self, input, fieldName):
        super(PayloadFilter, self).__init__(input)
        self.input = input

        self.fieldName = fieldName
        self.pos = 0
        self.i = 0

        self.posIncrAttr = input.addAttribute(PositionIncrementAttribute.class_)
        self.payloadAttr = input.addAttribute(PayloadAttribute.class_)
        self.termAttr = input.addAttribute(TermAttribute.class_)

    def incrementToken(self):
        if self.input.incrementToken():
            # stamp the current position into the payload before advancing
            data = JArray('byte')("pos: %d" %(self.pos))
            self.payloadAttr.setPayload(Payload(data))

            # odd-indexed tokens advance the position, even-indexed stack
            if self.i % 2 == 1:
                posIncr = 1
            else:
                posIncr = 0

            self.posIncrAttr.setPositionIncrement(posIncr)
            self.pos += posIncr
            self.i += 1

            return True

        return False


if __name__ == "__main__":
    import sys, lucene
    lucene.initVM()
    if '-loop' in sys.argv:
        sys.argv.remove('-loop')
        while True:
            try:
                main()
            except Exception:
                # keep looping on test failures, but do not swallow
                # KeyboardInterrupt/SystemExit (a bare 'except:' would,
                # making the loop unkillable with Ctrl-C)
                pass
    else:
        main()